Merge remote-tracking branch 'origin/main' into HEAD

Heikki Linnakangas
2025-07-20 00:58:57 +03:00
70 changed files with 2396 additions and 1148 deletions

Cargo.lock generated
View File

@@ -1951,6 +1951,7 @@ dependencies = [
 "diesel_derives",
 "itoa",
 "serde_json",
+ "uuid",
 ]

 [[package]]
@@ -3962,7 +3963,6 @@ version = "0.1.0"
 dependencies = [
 "ahash",
 "criterion",
- "foldhash",
 "hashbrown 0.15.4",
 "libc",
 "lock_api",
@@ -5509,7 +5509,7 @@ dependencies = [
 "reqwest-tracing",
 "rsa",
 "rstest",
- "rustc-hash 1.1.0",
+ "rustc-hash 2.1.1",
 "rustls 0.23.27",
 "rustls-native-certs 0.8.0",
 "rustls-pemfile 2.1.1",
@@ -7138,6 +7138,7 @@ dependencies = [
 "tokio-util",
 "tracing",
 "utils",
+ "uuid",
 "workspace_hack",
 ]
@@ -8444,6 +8445,7 @@ dependencies = [
 "tracing-error",
 "tracing-subscriber",
 "tracing-utils",
+ "uuid",
 "walkdir",
 ]
@@ -9063,7 +9065,6 @@ dependencies = [
 "tracing-log",
 "tracing-subscriber",
 "url",
- "uuid",
 "zeroize",
 "zstd",
 "zstd-safe",

View File

@@ -132,6 +132,7 @@ jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] }
 jsonwebtoken = "9"
 lasso = "0.7"
 libc = "0.2"
+lock_api = "0.4.13"
 md5 = "0.7.0"
 measured = { version = "0.0.22", features=["lasso"] }
 measured-process = { version = "0.0.22" }
@@ -168,7 +169,7 @@ reqwest-middleware = "0.4"
 reqwest-retry = "0.7"
 routerify = "3"
 rpds = "0.13"
-rustc-hash = "1.1.0"
+rustc-hash = "2.1.1"
 rustls = { version = "0.23.16", default-features = false }
 rustls-pemfile = "2"
 rustls-pki-types = "1.11"

View File

@@ -1,4 +1,4 @@
-use anyhow::{Context, Result, anyhow};
+use anyhow::{Context, Result};
 use chrono::{DateTime, Utc};
 use compute_api::privilege::Privilege;
 use compute_api::responses::{
@@ -7,7 +7,7 @@ use compute_api::responses::{
 };
 use compute_api::spec::{
     ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverConnectionInfo,
-    PageserverShardConnectionInfo, PgIdent,
+    PageserverProtocol, PageserverShardConnectionInfo, PageserverShardInfo, PgIdent,
 };
 use futures::StreamExt;
 use futures::future::join_all;
@@ -281,53 +281,114 @@ impl ParsedSpec {
     }
 }

-fn extract_pageserver_conninfo_from_guc(
-    pageserver_connstring_guc: &str,
-) -> PageserverConnectionInfo {
-    PageserverConnectionInfo {
-        shards: pageserver_connstring_guc
-            .split(',')
-            .enumerate()
-            .map(|(i, connstr)| {
-                (
-                    i as u32,
-                    PageserverShardConnectionInfo {
-                        libpq_url: Some(connstr.to_string()),
-                        grpc_url: None,
-                    },
-                )
-            })
-            .collect(),
-        prefer_grpc: false,
-    }
-}
+/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings.
+///
+/// This is used for backwards-compatibility, to parse the legacy `pageserver_connstr`
+/// field in the compute spec, or the 'neon.pageserver_connstring' GUC. Nowadays, the
+/// 'pageserver_connection_info' field should be used instead.
+fn extract_pageserver_conninfo_from_connstr(
+    connstr: &str,
+    stripe_size: Option<u32>,
+) -> Result<PageserverConnectionInfo, anyhow::Error> {
+    let shard_infos: Vec<_> = connstr
+        .split(',')
+        .map(|connstr| PageserverShardInfo {
+            pageservers: vec![PageserverShardConnectionInfo {
+                id: None,
+                libpq_url: Some(connstr.to_string()),
+                grpc_url: None,
+            }],
+        })
+        .collect();
+
+    match shard_infos.len() {
+        0 => anyhow::bail!("empty connection string"),
+        1 => {
+            // We assume that if there's only one connection string, it means "unsharded",
+            // rather than a sharded system with just a single shard. The latter is
+            // possible in principle, but we never do it.
+            let shard_count = ShardCount::unsharded();
+            let only_shard = shard_infos.first().unwrap().clone();
+            let shards = vec![(ShardIndex::unsharded(), only_shard)];
+            Ok(PageserverConnectionInfo {
+                shard_count,
+                stripe_size: None,
+                shards: shards.into_iter().collect(),
+                prefer_protocol: PageserverProtocol::Libpq,
+            })
+        }
+        n => {
+            if stripe_size.is_none() {
+                anyhow::bail!("{n} shards but no stripe_size");
+            }
+            let shard_count = ShardCount(n.try_into()?);
+            let shards = shard_infos
+                .into_iter()
+                .enumerate()
+                .map(|(idx, shard_info)| {
+                    (
+                        ShardIndex {
+                            shard_count,
+                            shard_number: ShardNumber(
+                                idx.try_into().expect("shard number fits in u8"),
+                            ),
+                        },
+                        shard_info,
+                    )
+                })
+                .collect();
+            Ok(PageserverConnectionInfo {
+                shard_count,
+                stripe_size,
+                shards,
+                prefer_protocol: PageserverProtocol::Libpq,
+            })
+        }
+    }
+}
 impl TryFrom<ComputeSpec> for ParsedSpec {
-    type Error = String;
+    type Error = anyhow::Error;

-    fn try_from(spec: ComputeSpec) -> Result<Self, String> {
+    fn try_from(spec: ComputeSpec) -> Result<Self, anyhow::Error> {
         // Extract the options from the spec file that are needed to connect to
         // the storage system.
         //
-        // For backwards-compatibility, the top-level fields in the spec file
-        // may be empty. In that case, we need to dig them from the GUCs in the
-        // cluster.settings field.
-        let pageserver_conninfo = match &spec.pageserver_connection_info {
-            Some(x) => x.clone(),
-            None => {
-                if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") {
-                    extract_pageserver_conninfo_from_guc(&guc)
-                } else {
-                    return Err("pageserver connstr should be provided".to_string());
-                }
-            }
-        };
+        // In compute specs generated by old control plane versions, the spec file might
+        // be missing the `pageserver_connection_info` field. In that case, we need to dig
+        // the pageserver connection info from the `pageserver_connstr` field instead, or
+        // if that's missing too, from the GUC in the cluster.settings field.
+        let mut pageserver_conninfo = spec.pageserver_connection_info.clone();
+        if pageserver_conninfo.is_none() {
+            if let Some(pageserver_connstr_field) = &spec.pageserver_connstring {
+                pageserver_conninfo = Some(extract_pageserver_conninfo_from_connstr(
+                    pageserver_connstr_field,
+                    spec.shard_stripe_size,
+                )?);
+            }
+        }
+        if pageserver_conninfo.is_none() {
+            if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") {
+                let stripe_size = if let Some(guc) = spec.cluster.settings.find("neon.stripe_size")
+                {
+                    Some(u32::from_str(&guc)?)
+                } else {
+                    None
+                };
+                pageserver_conninfo =
+                    Some(extract_pageserver_conninfo_from_connstr(&guc, stripe_size)?);
+            }
+        }
+        let pageserver_conninfo = pageserver_conninfo.ok_or(anyhow::anyhow!(
+            "pageserver connection information should be provided"
+        ))?;

+        // Similarly for safekeeper connection strings
         let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() {
             if matches!(spec.mode, ComputeMode::Primary) {
                 spec.cluster
                     .settings
                     .find("neon.safekeepers")
-                    .ok_or("safekeeper connstrings should be provided")?
+                    .ok_or(anyhow::anyhow!("safekeeper connstrings should be provided"))?
                     .split(',')
                     .map(|str| str.to_string())
                     .collect()
@@ -342,22 +403,22 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
         let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id {
             tenant_id
         } else {
-            spec.cluster
+            let guc = spec
+                .cluster
                 .settings
                 .find("neon.tenant_id")
-                .ok_or("tenant id should be provided")
-                .map(|s| TenantId::from_str(&s))?
-                .or(Err("invalid tenant id"))?
+                .ok_or(anyhow::anyhow!("tenant id should be provided"))?;
+            TenantId::from_str(&guc).context("invalid tenant id")?
         };
         let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id {
             timeline_id
         } else {
-            spec.cluster
+            let guc = spec
+                .cluster
                 .settings
                 .find("neon.timeline_id")
-                .ok_or("timeline id should be provided")
-                .map(|s| TimelineId::from_str(&s))?
-                .or(Err("invalid timeline id"))?
+                .ok_or(anyhow::anyhow!("timeline id should be provided"))?;
+            TimelineId::from_str(&guc).context(anyhow::anyhow!("invalid timeline id"))?
         };

         let endpoint_storage_addr: Option<String> = spec
@@ -381,7 +442,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
         };

         // Now check validity of the parsed specification
-        res.validate()?;
+        res.validate().map_err(anyhow::Error::msg)?;
         Ok(res)
     }
 }
@@ -461,7 +522,7 @@ impl ComputeNode {
         let mut new_state = ComputeState::new();
         if let Some(spec) = config.spec {
-            let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow!(msg))?;
+            let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
             new_state.pspec = Some(pspec);
         }
@@ -1069,10 +1130,9 @@ impl ComputeNode {
         let spec = compute_state.pspec.as_ref().expect("spec must be set");
         let started = Instant::now();

-        let (connected, size) = if spec.pageserver_conninfo.prefer_grpc {
-            self.try_get_basebackup_grpc(spec, lsn)?
-        } else {
-            self.try_get_basebackup_libpq(spec, lsn)?
+        let (connected, size) = match spec.pageserver_conninfo.prefer_protocol {
+            PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
+            PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?,
         };

         self.fix_zenith_signal_neon_signal()?;
@@ -1110,24 +1170,32 @@ impl ComputeNode {
     /// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when
     /// the connection was established, and the (compressed) size of the basebackup.
     fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
+        let shard0_index = ShardIndex {
+            shard_number: ShardNumber(0),
+            shard_count: spec.pageserver_conninfo.shard_count,
+        };
         let shard0 = spec
             .pageserver_conninfo
             .shards
-            .get(&0)
-            .expect("shard 0 connection info missing");
-        let shard0_url = shard0.grpc_url.clone().expect("no grpc_url for shard 0");
-        let shard_index = match spec.pageserver_conninfo.shards.len() as u8 {
-            0 | 1 => ShardIndex::unsharded(),
-            count => ShardIndex::new(ShardNumber(0), ShardCount(count)),
-        };
+            .get(&shard0_index)
+            .ok_or_else(|| {
+                anyhow::anyhow!("shard connection info missing for shard {}", shard0_index)
+            })?;
+        let pageserver = shard0
+            .pageservers
+            .first()
+            .expect("must have at least one pageserver");
+        let shard0_url = pageserver
+            .grpc_url
+            .clone()
+            .expect("no grpc_url for shard 0");

         let (reader, connected) = tokio::runtime::Handle::current().block_on(async move {
             let mut client = page_api::Client::connect(
                 shard0_url,
                 spec.tenant_id,
                 spec.timeline_id,
-                shard_index,
+                shard0_index,
                 spec.storage_auth_token.clone(),
                 None, // NB: base backups use payload compression
             )
@@ -1159,12 +1227,25 @@ impl ComputeNode {
     /// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp
     /// when the connection was established, and the (compressed) size of the basebackup.
     fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
+        let shard0_index = ShardIndex {
+            shard_number: ShardNumber(0),
+            shard_count: spec.pageserver_conninfo.shard_count,
+        };
         let shard0 = spec
             .pageserver_conninfo
             .shards
-            .get(&0)
-            .expect("shard 0 connection info missing");
-        let shard0_connstr = shard0.libpq_url.clone().expect("no libpq_url for shard 0");
+            .get(&shard0_index)
+            .ok_or_else(|| {
+                anyhow::anyhow!("shard connection info missing for shard {}", shard0_index)
+            })?;
+        let pageserver = shard0
+            .pageservers
+            .first()
+            .expect("must have at least one pageserver");
+        let shard0_connstr = pageserver
+            .libpq_url
+            .clone()
+            .expect("no libpq_url for shard 0");
         let mut config = postgres::Config::from_str(&shard0_connstr)?;

         // Use the storage auth token from the config file, if given.
@@ -2128,7 +2209,7 @@ LIMIT 100",
             self.params
                 .remote_ext_base_url
                 .as_ref()
-                .ok_or(DownloadError::BadInput(anyhow!(
+                .ok_or(DownloadError::BadInput(anyhow::anyhow!(
                     "Remote extensions storage is not configured",
                 )))?;
@@ -2324,7 +2405,7 @@ LIMIT 100",
         let remote_extensions = spec
             .remote_extensions
             .as_ref()
-            .ok_or(anyhow!("Remote extensions are not configured"))?;
+            .ok_or(anyhow::anyhow!("Remote extensions are not configured"))?;

         info!("parse shared_preload_libraries from spec.cluster.settings");
         let mut libs_vec = Vec::new();
@@ -2472,14 +2553,31 @@ LIMIT 100",
     pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
         self.terminate_lfc_offload_task();
         let secs = interval.as_secs();
-        info!("spawning lfc offload worker with {secs}s interval");
         let this = self.clone();
+        info!("spawning LFC offload worker with {secs}s interval");
         let handle = spawn(async move {
             let mut interval = time::interval(interval);
             interval.tick().await; // returns immediately
             loop {
                 interval.tick().await;
-                this.offload_lfc_async().await;
+                let prewarm_state = this.state.lock().unwrap().lfc_prewarm_state.clone();
+                // Do not offload LFC state if we are currently prewarming or any issue occurred.
+                // If we did, we might override the LFC state in endpoint storage with some
+                // incomplete state. Imagine this situation:
+                // 1. The endpoint is started with `autoprewarm: true`
+                // 2. While prewarming is not yet completed, we upload the new, incomplete state
+                // 3. The compute gets interrupted and restarts
+                // 4. We start again and try to prewarm with the state from 2. instead of the previous complete state
+                if matches!(
+                    prewarm_state,
+                    LfcPrewarmState::Completed
+                        | LfcPrewarmState::NotPrewarmed
+                        | LfcPrewarmState::Skipped
+                ) {
+                    this.offload_lfc_async().await;
+                }
             }
         });
         *self.lfc_offload_task.lock().unwrap() = Some(handle);
@@ -2631,7 +2729,10 @@ mod tests {
         match ParsedSpec::try_from(spec.clone()) {
             Ok(_p) => panic!("Failed to detect duplicate entry"),
-            Err(e) => assert!(e.starts_with("duplicate entry in safekeeper_connstrings:")),
+            Err(e) => assert!(
+                e.to_string()
+                    .starts_with("duplicate entry in safekeeper_connstrings:")
+            ),
         };
     }
 }
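As a side note for readers of this diff: the following is a minimal usage sketch, not part of the commit, of the new `extract_pageserver_conninfo_from_connstr` helper. Host names and ports are placeholders; it assumes the `ShardCount`/`ShardIndex`/`ShardNumber` types from `utils::shard` shown above.

```rust
fn example() -> anyhow::Result<()> {
    use utils::shard::{ShardCount, ShardIndex, ShardNumber};

    // Two comma-separated legacy libpq connstrings plus a stripe size
    // yield a two-shard PageserverConnectionInfo.
    let conninfo = extract_pageserver_conninfo_from_connstr(
        "postgres://no_user@ps1:6400,postgres://no_user@ps2:6400",
        Some(32768), // stripe size in pages; required once there is more than one shard
    )?;
    assert_eq!(conninfo.shard_count, ShardCount(2));
    assert_eq!(conninfo.stripe_size, Some(32768));

    // Shards are now keyed by ShardIndex instead of a bare u32 shard number.
    let shard0 = ShardIndex {
        shard_number: ShardNumber(0),
        shard_count: ShardCount(2),
    };
    assert!(conninfo.shards.contains_key(&shard0));
    Ok(())
}
```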

View File

@@ -89,7 +89,7 @@
         self.state.lock().unwrap().lfc_offload_state.clone()
     }

-    /// If there is a prewarm request ongoing, return false, true otherwise
+    /// If there is a prewarm request ongoing, return `false`, `true` otherwise.
     pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
         {
             let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -101,15 +101,25 @@
         let cloned = self.clone();
         spawn(async move {
-            let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
-                cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
-                return;
-            };
-            crate::metrics::LFC_PREWARM_ERRORS.inc();
-            error!(%err, "prewarming lfc");
-            cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
-                error: err.to_string(),
-            };
+            let state = match cloned.prewarm_impl(from_endpoint).await {
+                Ok(true) => LfcPrewarmState::Completed,
+                Ok(false) => {
+                    info!(
+                        "skipping LFC prewarm because LFC state is not found in endpoint storage"
+                    );
+                    LfcPrewarmState::Skipped
+                }
+                Err(err) => {
+                    crate::metrics::LFC_PREWARM_ERRORS.inc();
+                    error!(%err, "could not prewarm LFC");
+                    LfcPrewarmState::Failed {
+                        error: err.to_string(),
+                    }
+                }
+            };
+            cloned.state.lock().unwrap().lfc_prewarm_state = state;
         });
         true
     }
@@ -120,15 +130,21 @@
         EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
     }

-    async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
+    /// Request LFC state from endpoint storage and load corresponding pages into Postgres.
+    /// Returns a result with `false` if the LFC state is not found in endpoint storage.
+    async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<bool> {
         let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
         info!(%url, "requesting LFC state from endpoint storage");

         let request = Client::new().get(&url).bearer_auth(token);
         let res = request.send().await.context("querying endpoint storage")?;
         let status = res.status();
-        if status != StatusCode::OK {
-            bail!("{status} querying endpoint storage")
+        match status {
+            StatusCode::OK => (),
+            StatusCode::NOT_FOUND => {
+                return Ok(false);
+            }
+            _ => bail!("{status} querying endpoint storage"),
         }

         let mut uncompressed = Vec::new();
@@ -141,7 +157,8 @@
             .await
             .context("decoding LFC state")?;
         let uncompressed_len = uncompressed.len();
-        info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
+        info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres");

         ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
             .await
@@ -149,7 +166,9 @@
             .query_one("select neon.prewarm_local_cache($1)", &[&uncompressed])
             .await
             .context("loading LFC state into postgres")
-            .map(|_| ())
+            .map(|_| ())?;
+        Ok(true)
     }

     /// If offload request is ongoing, return false, true otherwise
@@ -177,12 +196,14 @@
     async fn offload_lfc_with_state_update(&self) {
         crate::metrics::LFC_OFFLOADS.inc();
         let Err(err) = self.offload_lfc_impl().await else {
             self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
             return;
         };
         crate::metrics::LFC_OFFLOAD_ERRORS.inc();
-        error!(%err, "offloading lfc");
+        error!(%err, "could not offload LFC state to endpoint storage");
         self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
             error: err.to_string(),
         };
@@ -190,7 +211,7 @@
     async fn offload_lfc_impl(&self) -> Result<()> {
         let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
-        info!(%url, "requesting LFC state from postgres");
+        info!(%url, "requesting LFC state from Postgres");

         let mut compressed = Vec::new();
         ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
@@ -205,13 +226,17 @@
             .read_to_end(&mut compressed)
             .await
             .context("compressing LFC state")?;
         let compressed_len = compressed.len();
         info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");

         let request = Client::new().put(url).bearer_auth(token).body(compressed);
         match request.send().await {
             Ok(res) if res.status() == StatusCode::OK => Ok(()),
-            Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
+            Ok(res) => bail!(
+                "Request to endpoint storage failed with status: {}",
+                res.status()
+            ),
             Err(err) => Err(err).context("writing to endpoint storage"),
         }
     }

View File

@@ -15,6 +15,8 @@ use crate::pg_helpers::{
 };
 use crate::tls::{self, SERVER_CRT, SERVER_KEY};

+use utils::shard::{ShardIndex, ShardNumber};
+
 /// Check that `line` is inside a text file and put it there if it is not.
 /// Create file if it doesn't exist.
 pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
@@ -58,24 +60,53 @@ pub fn write_postgres_conf(
     // Add options for connecting to storage
     writeln!(file, "# Neon storage settings")?;
+    writeln!(file)?;
     if let Some(conninfo) = &spec.pageserver_connection_info {
+        // Stripe size GUC should be defined prior to connection string
+        if let Some(stripe_size) = conninfo.stripe_size {
+            writeln!(
+                file,
+                "# from compute spec's pageserver_conninfo.stripe_size field"
+            )?;
+            writeln!(file, "neon.stripe_size={stripe_size}")?;
+        }
+
         let mut libpq_urls: Option<Vec<String>> = Some(Vec::new());
         let mut grpc_urls: Option<Vec<String>> = Some(Vec::new());
-        for shardno in 0..conninfo.shards.len() {
-            let info = conninfo.shards.get(&(shardno as u32)).ok_or_else(|| {
-                anyhow::anyhow!("shard {shardno} missing from pageserver_connection_info shard map")
+        let num_shards = if conninfo.shard_count.0 == 0 {
+            1 // unsharded, treat it as a single shard
+        } else {
+            conninfo.shard_count.0
+        };
+        for shard_number in 0..num_shards {
+            let shard_index = ShardIndex {
+                shard_number: ShardNumber(shard_number),
+                shard_count: conninfo.shard_count,
+            };
+            let info = conninfo.shards.get(&shard_index).ok_or_else(|| {
+                anyhow::anyhow!(
+                    "shard {shard_index} missing from pageserver_connection_info shard map"
+                )
             })?;
-            if let Some(url) = &info.libpq_url {
+            let first_pageserver = info
+                .pageservers
+                .first()
+                .expect("must have at least one pageserver");
+
+            // Add the libpq URL to the array, or if the URL is missing, reset the array
+            // forgetting any previous entries. All servers must have a libpq URL, or none
+            // at all.
+            if let Some(url) = &first_pageserver.libpq_url {
                 if let Some(ref mut urls) = libpq_urls {
                     urls.push(url.clone());
                 }
             } else {
                 libpq_urls = None
             }
-            if let Some(url) = &info.grpc_url {
+            // Similarly for gRPC URLs
+            if let Some(url) = &first_pageserver.grpc_url {
                 if let Some(ref mut urls) = grpc_urls {
                     urls.push(url.clone());
                 }
@@ -84,6 +115,10 @@
             }
         }
         if let Some(libpq_urls) = libpq_urls {
+            writeln!(
+                file,
+                "# derived from compute spec's pageserver_conninfo field"
+            )?;
             writeln!(
                 file,
                 "neon.pageserver_connstring={}",
@@ -93,6 +128,10 @@
             writeln!(file, "# no neon.pageserver_connstring")?;
         }
         if let Some(grpc_urls) = grpc_urls {
+            writeln!(
+                file,
+                "# derived from compute spec's pageserver_conninfo field"
+            )?;
             writeln!(
                 file,
                 "neon.pageserver_grpc_urls={}",
@@ -101,11 +140,19 @@
         } else {
             writeln!(file, "# no neon.pageserver_grpc_urls")?;
         }
+    } else {
+        // Stripe size GUC should be defined prior to connection string
+        if let Some(stripe_size) = spec.shard_stripe_size {
+            writeln!(file, "# from compute spec's shard_stripe_size field")?;
+            writeln!(file, "neon.stripe_size={stripe_size}")?;
+        }
+        if let Some(s) = &spec.pageserver_connstring {
+            writeln!(file, "# from compute spec's pageserver_connstring field")?;
+            writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
+        }
     }
-    if let Some(stripe_size) = spec.shard_stripe_size {
-        writeln!(file, "neon.stripe_size={stripe_size}")?;
-    }

     if !spec.safekeeper_connstrings.is_empty() {
         let mut neon_safekeepers_value = String::new();
         tracing::info!(
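One detail of the hunk above worth spelling out: the `Option<Vec<String>>` accumulators are all-or-nothing, so a single shard without a libpq (or gRPC) URL suppresses the corresponding GUC entirely. A standalone sketch of that pattern, with a placeholder URL:

```rust
fn main() {
    // One shard without a libpq URL resets the accumulator to None,
    // so the GUC is only written when every shard has a URL.
    let urls_per_shard = [Some("postgres://no_user@ps1:6400"), None];
    let mut libpq_urls: Option<Vec<String>> = Some(Vec::new());
    for url in urls_per_shard {
        match (url, libpq_urls.as_mut()) {
            (Some(u), Some(urls)) => urls.push(u.to_string()),
            _ => libpq_urls = None,
        }
    }
    assert_eq!(libpq_urls, None); // one shard lacked a URL, so no GUC is emitted
}
```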

View File

@@ -613,11 +613,11 @@
+      - skipped
       properties:
         status:
-          description: Lfc prewarm status
-          enum: [not_prewarmed, prewarming, completed, failed]
+          description: LFC prewarm status
+          enum: [not_prewarmed, prewarming, completed, failed, skipped]
           type: string
         error:
-          description: Lfc prewarm error, if any
+          description: LFC prewarm error, if any
           type: string
         total:
           description: Total pages processed
@@ -635,11 +635,11 @@
       - status
       properties:
         status:
-          description: Lfc offload status
+          description: LFC offload status
           enum: [not_offloaded, offloading, completed, failed]
           type: string
         error:
-          description: Lfc offload error, if any
+          description: LFC offload error, if any
           type: string
     PromoteState:

View File

@@ -4,13 +4,13 @@ use std::thread;
 use std::time::{Duration, SystemTime};

 use anyhow::{Result, bail};
-use compute_api::spec::{ComputeMode, PageserverConnectionInfo};
+use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverProtocol};
 use pageserver_page_api as page_api;
 use postgres::{NoTls, SimpleQueryMessage};
 use tracing::{info, warn};
 use utils::id::{TenantId, TimelineId};
 use utils::lsn::Lsn;
-use utils::shard::{ShardCount, ShardNumber, TenantShardId};
+use utils::shard::TenantShardId;

 use crate::compute::ComputeNode;
@@ -116,37 +116,38 @@
     timeline_id: TimelineId,
     lsn: Lsn,
 ) -> Result<Option<SystemTime>> {
-    let shard_count = conninfo.shards.len();
     let mut leases = Vec::new();
-    for (shard_number, shard) in conninfo.shards.into_iter() {
-        let tenant_shard_id = match shard_count {
-            0 | 1 => TenantShardId::unsharded(tenant_id),
-            shard_count => TenantShardId {
-                tenant_id,
-                shard_number: ShardNumber(shard_number as u8),
-                shard_count: ShardCount::new(shard_count as u8),
-            },
-        };
-        let lease = if conninfo.prefer_grpc {
-            acquire_lsn_lease_grpc(
-                &shard.grpc_url.unwrap(),
-                auth,
-                tenant_shard_id,
-                timeline_id,
-                lsn,
-            )?
-        } else {
-            acquire_lsn_lease_libpq(
-                &shard.libpq_url.unwrap(),
-                auth,
-                tenant_shard_id,
-                timeline_id,
-                lsn,
-            )?
-        };
-        leases.push(lease);
+    for (shard_index, shard) in conninfo.shards.into_iter() {
+        let tenant_shard_id = TenantShardId {
+            tenant_id,
+            shard_number: shard_index.shard_number,
+            shard_count: shard_index.shard_count,
+        };
+        // XXX: If there is more than one pageserver for a shard, do we need to get a
+        // lease on all of them? Currently, that's what we assume, but this is hypothetical
+        // as of this writing, as we never pass the info for more than one pageserver per
+        // shard.
+        for pageserver in shard.pageservers {
+            let lease = match conninfo.prefer_protocol {
+                PageserverProtocol::Grpc => acquire_lsn_lease_grpc(
+                    &pageserver.grpc_url.unwrap(),
+                    auth,
+                    tenant_shard_id,
+                    timeline_id,
+                    lsn,
+                )?,
+                PageserverProtocol::Libpq => acquire_lsn_lease_libpq(
+                    &pageserver.libpq_url.unwrap(),
+                    auth,
+                    tenant_shard_id,
+                    timeline_id,
+                    lsn,
+                )?,
+            };
+            leases.push(lease);
+        }
     }

     Ok(leases.into_iter().min().flatten())
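One subtlety in the final `Ok(leases.into_iter().min().flatten())`: for `Option`, `None` sorts before `Some(_)`, so a single shard that yielded no lease makes the combined result `None`; otherwise the earliest expiry wins. A quick self-contained illustration:

```rust
use std::time::{Duration, SystemTime};

fn main() {
    let earlier = SystemTime::now();
    let later = earlier + Duration::from_secs(60);

    // None sorts before Some(_), so one shard without a lease collapses
    // the combined result to None.
    let leases = vec![Some(later), Some(earlier), None];
    assert_eq!(leases.into_iter().min().flatten(), None);

    // With every shard holding a lease, the earliest expiry wins.
    let leases = vec![Some(later), Some(earlier)];
    assert_eq!(leases.into_iter().min().flatten(), Some(earlier));
}
```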

View File

@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.

 ## Example: Start with Postgres 16

-To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
+To create and start a local development environment with Postgres 16, you will need to provide the `--pg-version` flag to 2 of the start-up commands.

 ```shell
-cargo neon init --pg-version 16
+cargo neon init
 cargo neon start
 cargo neon tenant create --set-default --pg-version 16
 cargo neon endpoint create main --pg-version 16

View File

@@ -16,9 +16,14 @@ use std::time::Duration;
 use anyhow::{Context, Result, anyhow, bail};
 use clap::Parser;
 use compute_api::requests::ComputeClaimsScope;
-use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverShardConnectionInfo};
+use compute_api::spec::{
+    ComputeMode, PageserverConnectionInfo, PageserverProtocol, PageserverShardInfo,
+};
 use control_plane::broker::StorageBroker;
 use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode};
+use control_plane::endpoint::{
+    pageserver_conf_to_shard_conn_info, tenant_locate_response_to_conn_info,
+};
 use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
 use control_plane::local_env;
 use control_plane::local_env::{
@@ -44,7 +49,6 @@ use pageserver_api::models::{
 };
 use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
 use postgres_backend::AuthType;
-use postgres_connection::parse_host_port;
 use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
 use safekeeper_api::{
     DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
@@ -52,11 +56,11 @@ use safekeeper_api::{
 };
 use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR;
 use tokio::task::JoinSet;
-use url::Host;
 use utils::auth::{Claims, Scope};
 use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
 use utils::lsn::Lsn;
 use utils::project_git_version;
+use utils::shard::ShardIndex;

 // Default id of a safekeeper node, if not specified on the command line.
 const DEFAULT_SAFEKEEPER_ID: NodeId = NodeId(1);
@@ -1521,74 +1525,56 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
             )?;
         }

-        let (shards, stripe_size) = if let Some(ps_id) = pageserver_id {
-            let conf = env.get_pageserver_conf(ps_id).unwrap();
-            let libpq_url = Some({
-                let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
-                let port = port.unwrap_or(5432);
-                format!("postgres://no_user@{host}:{port}")
-            });
-            let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr {
-                let (host, port) = parse_host_port(grpc_addr)?;
-                let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
-                Some(format!("grpc://no_user@{host}:{port}"))
-            } else {
-                None
-            };
-            let pageserver = PageserverShardConnectionInfo {
-                libpq_url,
-                grpc_url,
-            };
+        let prefer_protocol = if endpoint.grpc {
+            PageserverProtocol::Grpc
+        } else {
+            PageserverProtocol::Libpq
+        };
+        let mut pageserver_conninfo = if let Some(ps_id) = pageserver_id {
+            let conf = env.get_pageserver_conf(ps_id).unwrap();
+            let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?;
+            let shard_info = PageserverShardInfo {
+                pageservers: vec![ps_conninfo],
+            };
             // If caller is telling us what pageserver to use, this is not a tenant which is
             // fully managed by storage controller, therefore not sharded.
-            (vec![(0, pageserver)], DEFAULT_STRIPE_SIZE)
+            let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)]
+                .into_iter()
+                .collect();
+            PageserverConnectionInfo {
+                shard_count: ShardCount(0),
+                stripe_size: None,
+                shards,
+                prefer_protocol,
+            }
         } else {
             // Look up the currently attached location of the tenant, and its striping metadata,
             // to pass these on to postgres.
             let storage_controller = StorageController::from_env(env);
             let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
-            let shards = futures::future::try_join_all(locate_result.shards.into_iter().map(
-                |shard| async move {
-                    if let ComputeMode::Static(lsn) = endpoint.mode {
-                        // Initialize LSN leases for static computes.
-                        let conf = env.get_pageserver_conf(shard.node_id).unwrap();
-                        let pageserver = PageServerNode::from_env(env, conf);
-                        pageserver
-                            .http_client
-                            .timeline_init_lsn_lease(shard.shard_id, endpoint.timeline_id, lsn)
-                            .await?;
-                    }
-                    let libpq_host = Host::parse(&shard.listen_pg_addr)?;
-                    let libpq_port = shard.listen_pg_port;
-                    let libpq_url =
-                        Some(format!("postgres://no_user@{libpq_host}:{libpq_port}"));
-                    let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr {
-                        let grpc_port = shard.listen_grpc_port.expect("no gRPC port");
-                        Some(format!("grpc://no_user@{grpc_host}:{grpc_port}"))
-                    } else {
-                        None
-                    };
-                    let pageserver = PageserverShardConnectionInfo {
-                        libpq_url,
-                        grpc_url,
-                    };
-                    anyhow::Ok((shard.shard_id.shard_number.0 as u32, pageserver))
-                },
-            ))
-            .await?;
-            let stripe_size = locate_result.shard_params.stripe_size;
-            (shards, stripe_size)
-        };
-        assert!(!shards.is_empty());
-        let pageserver_conninfo = PageserverConnectionInfo {
-            shards: shards.into_iter().collect(),
-            prefer_grpc: endpoint.grpc,
-        };
+            assert!(!locate_result.shards.is_empty());
+
+            // Initialize LSN leases for static computes.
+            if let ComputeMode::Static(lsn) = endpoint.mode {
+                futures::future::try_join_all(locate_result.shards.iter().map(
+                    |shard| async move {
+                        let conf = env.get_pageserver_conf(shard.node_id).unwrap();
+                        let pageserver = PageServerNode::from_env(env, conf);
+                        pageserver
+                            .http_client
+                            .timeline_init_lsn_lease(shard.shard_id, endpoint.timeline_id, lsn)
+                            .await
+                    },
+                ))
+                .await?;
+            }
+
+            tenant_locate_response_to_conn_info(&locate_result)?
+        };
+        pageserver_conninfo.prefer_protocol = prefer_protocol;

         let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
         let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
@@ -1620,7 +1606,6 @@
                 safekeepers,
                 pageserver_conninfo,
                 remote_ext_base_url: remote_ext_base_url.clone(),
-                shard_stripe_size: stripe_size.0 as usize,
                 create_test_user: args.create_test_user,
                 start_timeout: args.start_timeout,
                 autoprewarm: args.autoprewarm,
@@ -1637,66 +1622,45 @@
                 .endpoints
                 .get(endpoint_id.as_str())
                 .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
-            let shards = if let Some(ps_id) = args.endpoint_pageserver_id {
+            let prefer_protocol = if endpoint.grpc {
+                PageserverProtocol::Grpc
+            } else {
+                PageserverProtocol::Libpq
+            };
+            let mut pageserver_conninfo = if let Some(ps_id) = args.endpoint_pageserver_id {
                 let conf = env.get_pageserver_conf(ps_id)?;
-                let libpq_url = Some({
-                    let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
-                    let port = port.unwrap_or(5432);
-                    format!("postgres://no_user@{host}:{port}")
-                });
-                let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr {
-                    let (host, port) = parse_host_port(grpc_addr)?;
-                    let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
-                    Some(format!("grpc://no_user@{host}:{port}"))
-                } else {
-                    None
-                };
-                let pageserver = PageserverShardConnectionInfo {
-                    libpq_url,
-                    grpc_url,
-                };
+                let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?;
+                let shard_info = PageserverShardInfo {
+                    pageservers: vec![ps_conninfo],
+                };
                 // If caller is telling us what pageserver to use, this is not a tenant which is
                 // fully managed by storage controller, therefore not sharded.
-                vec![(0, pageserver)]
-            } else {
-                let storage_controller = StorageController::from_env(env);
-                storage_controller
-                    .tenant_locate(endpoint.tenant_id)
-                    .await?
-                    .shards
-                    .into_iter()
-                    .map(|shard| {
-                        // Use gRPC if requested.
-                        let libpq_host = Host::parse(&shard.listen_pg_addr).expect("bad hostname");
-                        let libpq_port = shard.listen_pg_port;
-                        let libpq_url =
-                            Some(format!("postgres://no_user@{libpq_host}:{libpq_port}"));
-                        let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr {
-                            let grpc_port = shard.listen_grpc_port.expect("no gRPC port");
-                            Some(format!("grpc://no_user@{grpc_host}:{grpc_port}"))
-                        } else {
-                            None
-                        };
-                        (
-                            shard.shard_id.shard_number.0 as u32,
-                            PageserverShardConnectionInfo {
-                                libpq_url,
-                                grpc_url,
-                            },
-                        )
-                    })
-                    .collect::<Vec<_>>()
-            };
-            let pageserver_conninfo = PageserverConnectionInfo {
-                shards: shards.into_iter().collect(),
-                prefer_grpc: endpoint.grpc,
-            };
+                let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)]
+                    .into_iter()
+                    .collect();
+                PageserverConnectionInfo {
+                    shard_count: ShardCount::unsharded(),
+                    stripe_size: None,
+                    shards,
+                    prefer_protocol,
+                }
+            } else {
+                // Look up the currently attached location of the tenant, and its striping metadata,
+                // to pass these on to postgres.
+                let storage_controller = StorageController::from_env(env);
+                let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
+                tenant_locate_response_to_conn_info(&locate_result)?
+            };
+            pageserver_conninfo.prefer_protocol = prefer_protocol;

             // If --safekeepers argument is given, use only the listed
             // safekeeper nodes; otherwise all from the env.
             let safekeepers = parse_safekeepers(&args.safekeepers)?;
             endpoint
-                .reconfigure(Some(pageserver_conninfo), None, safekeepers, None)
+                .reconfigure(Some(&pageserver_conninfo), safekeepers, None)
                 .await?;
         }
         EndpointCmd::Stop(args) => {

View File

@@ -37,7 +37,7 @@
 //!   <other PostgreSQL files>
 //! ```
 //!
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 use std::fmt::Display;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
 use std::path::PathBuf;
@@ -57,8 +57,8 @@ use compute_api::responses::{
     TlsConfig,
 };
 use compute_api::spec::{
-    Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent,
-    RemoteExtSpec, Role,
+    Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol,
+    PageserverShardInfo, PgIdent, RemoteExtSpec, Role,
 };

 // re-export these, because they're used in the reconfigure() function
@@ -69,7 +69,6 @@ use jsonwebtoken::jwk::{
     OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
 };
 use nix::sys::signal::{Signal, kill};
-use pageserver_api::shard::ShardStripeSize;
 use pem::Pem;
 use reqwest::header::CONTENT_TYPE;
 use safekeeper_api::PgMajorVersion;
@@ -80,6 +79,10 @@ use spki::der::Decode;
 use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
 use tracing::debug;
 use utils::id::{NodeId, TenantId, TimelineId};
+use utils::shard::{ShardIndex, ShardNumber};
+
+use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT;
+use postgres_connection::parse_host_port;

 use crate::local_env::LocalEnv;
 use crate::postgresql_conf::PostgresConf;
@@ -392,7 +395,6 @@ pub struct EndpointStartArgs {
     pub safekeepers: Vec<NodeId>,
     pub pageserver_conninfo: PageserverConnectionInfo,
     pub remote_ext_base_url: Option<String>,
-    pub shard_stripe_size: usize,
     pub create_test_user: bool,
     pub start_timeout: Duration,
     pub autoprewarm: bool,
@@ -724,6 +726,46 @@
             remote_extensions = None;
         };

+        // For the sake of backwards-compatibility, also fill in 'pageserver_connstring'
+        //
+        // XXX: I believe this is not really needed, except to make
+        // test_forward_compatibility happy.
+        //
+        // Use a closure so that we can conveniently return None in the middle of the
+        // loop.
+        let pageserver_connstring = (|| {
+            let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() {
+                1
+            } else {
+                args.pageserver_conninfo.shard_count.0
+            };
+            let mut connstrings = Vec::new();
+            for shard_no in 0..num_shards {
+                let shard_index = ShardIndex {
+                    shard_count: args.pageserver_conninfo.shard_count,
+                    shard_number: ShardNumber(shard_no),
+                };
+                let shard = args
+                    .pageserver_conninfo
+                    .shards
+                    .get(&shard_index)
+                    .expect(&format!(
+                        "shard {} not found in pageserver_connection_info",
+                        shard_index
+                    ));
+                let pageserver = shard
+                    .pageservers
+                    .first()
+                    .expect("must have at least one pageserver");
+                if let Some(libpq_url) = &pageserver.libpq_url {
+                    connstrings.push(libpq_url.clone());
+                } else {
+                    return None;
+                }
+            }
+            Some(connstrings.join(","))
+        })();
+
         // Create config file
         let config = {
             let mut spec = ComputeSpec {
@@ -768,13 +810,14 @@
                 branch_id: None,
                 endpoint_id: Some(self.endpoint_id.clone()),
                 mode: self.mode,
-                pageserver_connection_info: Some(args.pageserver_conninfo),
+                pageserver_connection_info: Some(args.pageserver_conninfo.clone()),
+                pageserver_connstring,
                 safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()),
                 safekeeper_connstrings,
                 storage_auth_token: args.auth_token.clone(),
                 remote_extensions,
                 pgbouncer_settings: None,
-                shard_stripe_size: Some(args.shard_stripe_size),
+                shard_stripe_size: args.pageserver_conninfo.stripe_size, // redundant with pageserver_connection_info.stripe_size
                 local_proxy_config: None,
                 reconfigure_concurrency: self.reconfigure_concurrency,
                 drop_subscriptions_before_start: self.drop_subscriptions_before_start,
@@ -986,8 +1029,7 @@
     pub async fn reconfigure(
         &self,
-        pageserver_conninfo: Option<PageserverConnectionInfo>,
-        stripe_size: Option<ShardStripeSize>,
+        pageserver_conninfo: Option<&PageserverConnectionInfo>,
         safekeepers: Option<Vec<NodeId>>,
         safekeeper_generation: Option<SafekeeperGeneration>,
     ) -> Result<()> {
@@ -1009,10 +1051,8 @@
                 !pageserver_conninfo.shards.is_empty(),
                 "no pageservers provided"
             );
-            spec.pageserver_connection_info = Some(pageserver_conninfo);
-        }
-        if stripe_size.is_some() {
-            spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
+            spec.pageserver_connection_info = Some(pageserver_conninfo.clone());
+            spec.shard_stripe_size = pageserver_conninfo.stripe_size;
         }

         // If safekeepers are not specified, don't change them.
@@ -1061,11 +1101,9 @@
     pub async fn reconfigure_pageservers(
         &self,
-        pageservers: PageserverConnectionInfo,
-        stripe_size: Option<ShardStripeSize>,
+        pageservers: &PageserverConnectionInfo,
     ) -> Result<()> {
-        self.reconfigure(Some(pageservers), stripe_size, None, None)
-            .await
+        self.reconfigure(Some(pageservers), None, None).await
     }

     pub async fn reconfigure_safekeepers(
@@ -1073,7 +1111,7 @@
         safekeepers: Vec<NodeId>,
         generation: SafekeeperGeneration,
     ) -> Result<()> {
-        self.reconfigure(None, None, Some(safekeepers), Some(generation))
+        self.reconfigure(None, Some(safekeepers), Some(generation))
             .await
     }
@@ -1129,3 +1167,68 @@
         )
     }
 }
+
+pub fn pageserver_conf_to_shard_conn_info(
+    conf: &crate::local_env::PageServerConf,
+) -> Result<PageserverShardConnectionInfo> {
+    let libpq_url = {
+        let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
+        let port = port.unwrap_or(5432);
+        Some(format!("postgres://no_user@{host}:{port}"))
+    };
+    let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr {
+        let (host, port) = parse_host_port(grpc_addr)?;
+        let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
+        Some(format!("grpc://no_user@{host}:{port}"))
+    } else {
+        None
+    };
+    Ok(PageserverShardConnectionInfo {
+        id: Some(conf.id.to_string()),
+        libpq_url,
+        grpc_url,
+    })
+}
+
+pub fn tenant_locate_response_to_conn_info(
+    response: &pageserver_api::controller_api::TenantLocateResponse,
+) -> Result<PageserverConnectionInfo> {
+    let mut shards = HashMap::new();
+    for shard in response.shards.iter() {
+        tracing::info!("parsing {}", shard.listen_pg_addr);
+        let libpq_url = {
+            let host = &shard.listen_pg_addr;
+            let port = shard.listen_pg_port;
+            Some(format!("postgres://no_user@{host}:{port}"))
+        };
+        let grpc_url = if let Some(grpc_addr) = &shard.listen_grpc_addr {
+            let host = grpc_addr;
+            let port = shard.listen_grpc_port.expect("no gRPC port");
+            Some(format!("grpc://no_user@{host}:{port}"))
+        } else {
+            None
+        };
+        let shard_info = PageserverShardInfo {
+            pageservers: vec![PageserverShardConnectionInfo {
+                id: Some(shard.node_id.to_string()),
+                libpq_url,
+                grpc_url,
+            }],
+        };
+        shards.insert(shard.shard_id.to_index(), shard_info);
+    }
+    let stripe_size = if response.shard_params.count.is_unsharded() {
+        None
+    } else {
+        Some(response.shard_params.stripe_size.0)
+    };
+    Ok(PageserverConnectionInfo {
+        shard_count: response.shard_params.count,
+        stripe_size,
+        shards,
+        prefer_protocol: PageserverProtocol::default(),
+    })
+}

View File

@@ -76,6 +76,12 @@ enum Command {
     NodeStartDelete {
         #[arg(long)]
         node_id: NodeId,
+        /// When `force` is true, skip waiting for shards to prewarm during migration.
+        /// This can significantly speed up node deletion since prewarming all shards
+        /// can take considerable time, but may result in slower initial access to
+        /// migrated shards until they warm up naturally.
+        #[arg(long)]
+        force: bool,
     },
     /// Cancel deletion of the specified pageserver and wait for `timeout`
     /// for the operation to be canceled. May be retried.
@@ -952,13 +958,14 @@
             .dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
             .await?;
         }
-        Command::NodeStartDelete { node_id } => {
+        Command::NodeStartDelete { node_id, force } => {
+            let query = if force {
+                format!("control/v1/node/{node_id}/delete?force=true")
+            } else {
+                format!("control/v1/node/{node_id}/delete")
+            };
             storcon_client
-                .dispatch::<(), ()>(
-                    Method::PUT,
-                    format!("control/v1/node/{node_id}/delete"),
-                    None,
-                )
+                .dispatch::<(), ()>(Method::PUT, query, None)
                 .await?;
             println!("Delete started for {node_id}");
         }

View File

@@ -46,16 +46,33 @@ pub struct ExtensionInstallResponse {
     pub version: ExtVersion,
 }

+/// Status of the LFC prewarm process. The same state machine is reused for
+/// both autoprewarm (prewarm after compute/Postgres start using the previously
+/// stored LFC state) and explicit prewarming via API.
 #[derive(Serialize, Default, Debug, Clone, PartialEq)]
 #[serde(tag = "status", rename_all = "snake_case")]
 pub enum LfcPrewarmState {
+    /// Default value when compute boots up.
     #[default]
     NotPrewarmed,
+    /// Prewarming thread is active and loading pages into LFC.
     Prewarming,
+    /// We found the requested LFC state in the endpoint storage and
+    /// completed prewarming successfully.
     Completed,
-    Failed {
-        error: String,
-    },
+    /// An unexpected error happened during prewarming. Note that a `Not Found 404`
+    /// response from the endpoint storage is explicitly excluded here,
+    /// because it can normally happen on the first compute start,
+    /// when the LFC state is not available yet.
+    Failed { error: String },
+    /// We tried to fetch the corresponding LFC state from the endpoint storage,
+    /// but received `Not Found 404`. This should normally happen only during the
+    /// first endpoint start after creation with `autoprewarm: true`.
+    ///
+    /// During an orchestrated prewarm via the API, when the caller explicitly
+    /// provides the LFC state key to prewarm from, it is the caller's responsibility
+    /// to handle this status as an error state.
+    Skipped,
 }

 impl Display for LfcPrewarmState {
@@ -64,6 +81,7 @@ impl Display for LfcPrewarmState {
             LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"),
             LfcPrewarmState::Prewarming => f.write_str("Prewarming"),
             LfcPrewarmState::Completed => f.write_str("Completed"),
+            LfcPrewarmState::Skipped => f.write_str("Skipped"),
             LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
         }
     }
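Given the `#[serde(tag = "status", rename_all = "snake_case")]` attributes on the enum, the new variant's wire format follows directly; a small sketch, assuming `serde_json`:

```rust
// Wire format implied by the serde attributes on LfcPrewarmState.
assert_eq!(
    serde_json::to_string(&LfcPrewarmState::Skipped).unwrap(),
    r#"{"status":"skipped"}"#
);
assert_eq!(
    serde_json::to_string(&LfcPrewarmState::Failed { error: "boom".into() }).unwrap(),
    r#"{"status":"failed","error":"boom"}"#
);
```

This is also what the `skipped` value added to the OpenAPI enum earlier in this commit corresponds to.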

View File

@@ -14,6 +14,7 @@ use serde::{Deserialize, Serialize};
 use url::Url;
 use utils::id::{TenantId, TimelineId};
 use utils::lsn::Lsn;
+use utils::shard::{ShardCount, ShardIndex};

 use crate::responses::TlsConfig;
@@ -106,11 +107,18 @@
     pub tenant_id: Option<TenantId>,
     pub timeline_id: Option<TimelineId>,

-    // Pageserver information can be passed in two different ways:
-    // 1. Here
-    // 2. in cluster.settings. This is legacy, we are switching to method 1.
+    /// Pageserver information can be passed in three different ways:
+    /// 1. Here, in `pageserver_connection_info`
+    /// 2. In the `pageserver_connstring` field.
+    /// 3. In `cluster.settings`.
+    ///
+    /// The goal is to use method 1. everywhere. But for backwards-compatibility with old
+    /// versions of the control plane, `compute_ctl` will check 2. and 3. if the
+    /// `pageserver_connection_info` field is missing.
     pub pageserver_connection_info: Option<PageserverConnectionInfo>,
+    pub pageserver_connstring: Option<String>,

     // More neon ids that we expose to the compute_ctl
     // and to postgres as neon extension GUCs.
     pub project_id: Option<String>,
@@ -145,7 +153,7 @@
     // Stripe size for pageserver sharding, in pages
     #[serde(default)]
-    pub shard_stripe_size: Option<usize>,
+    pub shard_stripe_size: Option<u32>,

     /// Local Proxy configuration used for JWT authentication
     #[serde(default)]
@@ -218,16 +226,28 @@ pub enum ComputeFeature {
UnknownFeature, UnknownFeature,
} }
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality. #[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
#[derive(Clone, Debug, Default, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverConnectionInfo { pub struct PageserverConnectionInfo {
pub shards: HashMap<u32, PageserverShardConnectionInfo>, /// NB: 0 for unsharded tenants, 1 for sharded tenants with 1 shard, following storage conventions.
pub shard_count: ShardCount,
pub prefer_grpc: bool, /// INVARIANT: null if shard_count is 0, otherwise non-null and immutable
pub stripe_size: Option<u32>,
pub shards: HashMap<ShardIndex, PageserverShardInfo>,
#[serde(default)]
pub prefer_protocol: PageserverProtocol,
} }
#[derive(Clone, Debug, Default, Deserialize, Serialize, Eq, PartialEq)] #[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardInfo {
pub pageservers: Vec<PageserverShardConnectionInfo>,
}
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardConnectionInfo { pub struct PageserverShardConnectionInfo {
pub id: Option<String>,
pub libpq_url: Option<String>, pub libpq_url: Option<String>,
pub grpc_url: Option<String>, pub grpc_url: Option<String>,
} }
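For illustration, a sketch constructing connection info for an unsharded tenant that honors the invariant above. The URL is a placeholder; `ShardIndex::unsharded()` is assumed from utils::shard:

use std::collections::HashMap;
use utils::shard::{ShardCount, ShardIndex};

fn unsharded_conninfo(libpq_url: &str) -> PageserverConnectionInfo {
    let shard = PageserverShardInfo {
        pageservers: vec![PageserverShardConnectionInfo {
            id: None,
            libpq_url: Some(libpq_url.to_string()),
            grpc_url: None,
        }],
    };
    PageserverConnectionInfo {
        shard_count: ShardCount::unsharded(),
        stripe_size: None, // INVARIANT: must be None when shard_count is 0
        shards: HashMap::from([(ShardIndex::unsharded(), shard)]),
        prefer_protocol: PageserverProtocol::default(), // libpq
    }
}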
@@ -465,13 +485,15 @@ pub struct JwksSettings {
pub jwt_audience: Option<String>, pub jwt_audience: Option<String>,
} }
/// Protocol used to connect to a Pageserver. Parsed from the connstring scheme. /// Protocol used to connect to a Pageserver.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] #[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
pub enum PageserverProtocol { pub enum PageserverProtocol {
/// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme. /// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme.
#[default] #[default]
#[serde(rename = "libpq")]
Libpq, Libpq,
/// A newer, gRPC-based protocol. Uses grpc:// scheme. /// A newer, gRPC-based protocol. Uses grpc:// scheme.
#[serde(rename = "grpc")]
Grpc, Grpc,
} }
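A sketch of scheme-based selection consistent with the doc comments above (the helper name is hypothetical; real parsing may go through a URL parser instead):

fn protocol_from_scheme(url: &str) -> Option<PageserverProtocol> {
    match url.split("://").next()? {
        "postgresql" | "postgres" => Some(PageserverProtocol::Libpq),
        "grpc" => Some(PageserverProtocol::Grpc),
        _ => None,
    }
}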

View File

@@ -4,12 +4,14 @@
//! a default registry. //! a default registry.
#![deny(clippy::undocumented_unsafe_blocks)] #![deny(clippy::undocumented_unsafe_blocks)]
use std::sync::RwLock;
use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels}; use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels};
use measured::metric::counter::CounterState; use measured::metric::counter::CounterState;
use measured::metric::gauge::GaugeState; use measured::metric::gauge::GaugeState;
use measured::metric::group::Encoding; use measured::metric::group::Encoding;
use measured::metric::name::{MetricName, MetricNameEncoder}; use measured::metric::name::{MetricName, MetricNameEncoder};
use measured::metric::{MetricEncoding, MetricFamilyEncoding}; use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup}; use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use prometheus::Registry; use prometheus::Registry;
@@ -116,12 +118,52 @@ pub fn pow2_buckets(start: usize, end: usize) -> Vec<f64> {
.collect() .collect()
} }
pub struct InfoMetric<L: LabelGroup, M: MetricType = GaugeState> {
label: RwLock<L>,
metric: M,
}
impl<L: LabelGroup> InfoMetric<L> {
pub fn new(label: L) -> Self {
Self::with_metric(label, GaugeState::new(1))
}
}
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
pub fn with_metric(label: L, metric: M) -> Self {
Self {
label: RwLock::new(label),
metric,
}
}
pub fn set_label(&self, label: L) {
*self.label.write().unwrap() = label;
}
}
impl<L, M, E> MetricFamilyEncoding<E> for InfoMetric<L, M>
where
L: LabelGroup,
M: MetricEncoding<E, Metadata = ()>,
E: Encoding,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut E,
) -> Result<(), E::Err> {
M::write_type(&name, enc)?;
self.metric
.collect_into(&(), &*self.label.read().unwrap(), name, enc)
}
}
pub struct BuildInfo { pub struct BuildInfo {
pub revision: &'static str, pub revision: &'static str,
pub build_tag: &'static str, pub build_tag: &'static str,
} }
// todo: allow label group without the set
impl LabelGroup for BuildInfo { impl LabelGroup for BuildInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) { fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const REVISION: &LabelName = LabelName::from_str("revision"); const REVISION: &LabelName = LabelName::from_str("revision");
@@ -131,24 +173,6 @@ impl LabelGroup for BuildInfo {
} }
} }
impl<T: Encoding> MetricFamilyEncoding<T> for BuildInfo
where
GaugeState: MetricEncoding<T>,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut T,
) -> Result<(), T::Err> {
enc.write_help(&name, "Build/version information")?;
GaugeState::write_type(&name, enc)?;
GaugeState {
count: std::sync::atomic::AtomicI64::new(1),
}
.collect_into(&(), self, name, enc)
}
}
#[derive(MetricGroup)] #[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))] #[metric(new(build_info: BuildInfo))]
pub struct NeonMetrics { pub struct NeonMetrics {
@@ -165,8 +189,8 @@ pub struct NeonMetrics {
#[derive(MetricGroup)] #[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))] #[metric(new(build_info: BuildInfo))]
pub struct LibMetrics { pub struct LibMetrics {
#[metric(init = build_info)] #[metric(init = InfoMetric::new(build_info))]
build_info: BuildInfo, build_info: InfoMetric<BuildInfo>,
#[metric(flatten)] #[metric(flatten)]
rusage: Rusage, rusage: Rusage,
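Usage sketch (hypothetical values): an info-style metric is a gauge pinned at 1 whose labels carry the payload, and `set_label` swaps the label set without touching the value:

let build_info = InfoMetric::new(BuildInfo {
    revision: "0123abc",
    build_tag: "dev",
});
// Later, relabel the same metric in place:
build_info.set_label(BuildInfo {
    revision: "0123abc",
    build_tag: "release-42",
});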

View File

@@ -8,20 +8,19 @@ license.workspace = true
thiserror.workspace = true thiserror.workspace = true
nix.workspace = true nix.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" } workspace_hack = { version = "0.1", path = "../../workspace_hack" }
rustc-hash = { version = "2.1.1" }
rand = "0.9.1"
libc.workspace = true libc.workspace = true
lock_api = "0.4.13" lock_api.workspace = true
rustc-hash.workspace = true
[dev-dependencies] [dev-dependencies]
criterion = { workspace = true, features = ["html_reports"] } criterion = { workspace = true, features = ["html_reports"] }
rand = "0.9"
rand_distr = "0.5.1" rand_distr = "0.5.1"
xxhash-rust = { version = "0.8.15", features = ["xxh3"] } xxhash-rust = { version = "0.8.15", features = ["xxh3"] }
ahash.workspace = true ahash.workspace = true
twox-hash = { version = "2.1.1" } twox-hash = { version = "2.1.1" }
seahash = "4.1.0" seahash = "4.1.0"
hashbrown = { git = "https://github.com/quantumish/hashbrown.git", rev = "6610e6d" } hashbrown = { git = "https://github.com/quantumish/hashbrown.git", rev = "6610e6d" }
foldhash = "0.1.5"
[target.'cfg(target_os = "macos")'.dependencies] [target.'cfg(target_os = "macos")'.dependencies]

View File

@@ -13,6 +13,8 @@
//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen //! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen
//! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the //! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the
//! dictionary by rehashing all keys. //! dictionary by rehashing all keys.
//!
//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
use std::fmt::Debug; use std::fmt::Debug;
use std::hash::{BuildHasher, Hash}; use std::hash::{BuildHasher, Hash};
@@ -30,6 +32,19 @@ mod tests;
use core::{Bucket, CoreHashMap, INVALID_POS}; use core::{Bucket, CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry}; use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
use thiserror::Error;
/// Error type for a hashmap shrink operation.
#[derive(Error, Debug)]
pub enum HashMapShrinkError {
/// There was an error encountered while resizing the memory area.
#[error("shmem resize failed: {0}")]
ResizeError(shmem::Error),
/// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
#[error("occupied entry in deallocated space found at {0}")]
RemainingEntries(usize),
}
/// This represents a hash table that (possibly) lives in shared memory. /// This represents a hash table that (possibly) lives in shared memory.
/// If a new process is launched with fork(), the child process inherits /// If a new process is launched with fork(), the child process inherits
/// this struct. /// this struct.
@@ -147,8 +162,8 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
}; };
let hashmap = CoreHashMap::new(buckets, dictionary); let hashmap = CoreHashMap::new(buckets, dictionary);
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
unsafe { unsafe {
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
std::ptr::write(shared_ptr, lock); std::ptr::write(shared_ptr, lock);
} }
@@ -171,6 +186,9 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
} }
/// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`]. /// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`].
///
/// This is a holdover from a previous implementation and is being kept around for
/// backwards compatibility reasons.
pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> { pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
self.attach_writer() self.attach_writer()
} }
@@ -184,8 +202,8 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// ///
/// [`libc::pthread_rwlock_t`] /// [`libc::pthread_rwlock_t`]
/// [`HashMapShared`] /// [`HashMapShared`]
/// [buckets] /// buckets
/// [dictionary] /// dictionary
/// ///
/// In between the above parts, there can be padding bytes to align the parts correctly. /// In between the above parts, there can be padding bytes to align the parts correctly.
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>; type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
@@ -310,6 +328,9 @@ where
} }
/// Get a reference to the entry containing a key. /// Get a reference to the entry containing a key.
///
/// NB: This takes a write lock, since there's no way to know in advance whether
/// the entry will be used for reading or for writing.
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> { pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
let hash = self.get_hash_value(&key); let hash = self.get_hash_value(&key);
self.entry_with_hash(key, hash) self.entry_with_hash(key, hash)
@@ -317,7 +338,7 @@ where
/// Remove a key given its hash. Returns the associated value if it existed. /// Remove a key given its hash. Returns the associated value if it existed.
pub fn remove(&self, key: &K) -> Option<V> { pub fn remove(&self, key: &K) -> Option<V> {
let hash = self.get_hash_value(&key); let hash = self.get_hash_value(key);
match self.entry_with_hash(key.clone(), hash) { match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => Some(e.remove()), Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None, Entry::Vacant(_) => None,
@@ -355,7 +376,7 @@ where
Some((key, _)) => Some(OccupiedEntry { Some((key, _)) => Some(OccupiedEntry {
_key: key.clone(), _key: key.clone(),
bucket_pos: pos as u32, bucket_pos: pos as u32,
prev_pos: entry::PrevPos::Unknown(self.get_hash_value(&key)), prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
map, map,
}), }),
_ => None, _ => None,
@@ -550,12 +571,7 @@ where
/// The following cases result in a panic: /// The following cases result in a panic:
/// - Calling this function on a map initialized with [`HashMapInit::with_fixed`]. /// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
/// - Calling this function on a map when no shrink operation is in progress. /// - Calling this function on a map when no shrink operation is in progress.
/// - Calling this function on a map with `shrink_mode` set to [`HashMapShrinkMode::Remap`] and pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
/// there are more buckets in use than the value returned by [`HashMapAccess::shrink_goal`].
///
/// # Errors
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
pub fn finish_shrink(&self) -> Result<(), shmem::Error> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!( assert!(
map.alloc_limit != INVALID_POS, map.alloc_limit != INVALID_POS,
@@ -574,10 +590,8 @@ where
); );
for i in (num_buckets as usize)..map.buckets.len() { for i in (num_buckets as usize)..map.buckets.len() {
if let Some((k, v)) = map.buckets[i].inner.take() { if map.buckets[i].inner.is_some() {
// alloc_bucket increases count, so need to decrease since we're just moving return Err(HashMapShrinkError::RemainingEntries(i));
map.buckets_in_use -= 1;
map.alloc_bucket(k, v).unwrap();
} }
} }
@@ -587,7 +601,9 @@ where
.expect("shrink called on a fixed-size hash table"); .expect("shrink called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets); let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?; if let Err(e) = shmem_handle.set_size(size_bytes) {
return Err(HashMapShrinkError::ResizeError(e));
}
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) }; let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
let buckets_ptr = map.buckets.as_mut_ptr(); let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets); self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
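Taken together, shrinking is a three-step protocol: declare the goal, evacuate buckets above it, then finish (which now fails with HashMapShrinkError::RemainingEntries instead of silently moving survivors). A sketch, using `TestKey` and `entry_at_bucket` as they appear in the tests further below:

fn shrink_to(
    writer: &mut HashMapAccess<TestKey, usize>,
    old_buckets: u32,
    new_buckets: u32,
) -> Result<(), HashMapShrinkError> {
    writer.begin_shrink(new_buckets);
    // Evacuate every bucket above the new limit; finish_shrink() would
    // otherwise fail with HashMapShrinkError::RemainingEntries.
    for i in new_buckets..old_buckets {
        if let Some(entry) = writer.entry_at_bucket(i as usize) {
            entry.remove();
        }
    }
    writer.finish_shrink()
}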

View File

@@ -43,9 +43,6 @@ pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) alloc_limit: u32, pub(crate) alloc_limit: u32,
/// The number of currently occupied buckets. /// The number of currently occupied buckets.
pub(crate) buckets_in_use: u32, pub(crate) buckets_in_use: u32,
// pub(crate) lock: libc::pthread_mutex_t,
// Unclear what the purpose of this is.
pub(crate) _user_list_head: u32,
} }
impl<'a, K, V> Debug for CoreHashMap<'a, K, V> impl<'a, K, V> Debug for CoreHashMap<'a, K, V>
@@ -66,7 +63,7 @@ where
/// Error for when there are no empty buckets left but one is needed. /// Error for when there are no empty buckets left but one is needed.
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct FullError(); pub struct FullError;
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
const FILL_FACTOR: f32 = 0.60; const FILL_FACTOR: f32 = 0.60;
@@ -118,7 +115,6 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
buckets, buckets,
free_head: 0, free_head: 0,
buckets_in_use: 0, buckets_in_use: 0,
_user_list_head: INVALID_POS,
alloc_limit: INVALID_POS, alloc_limit: INVALID_POS,
} }
} }
@@ -179,7 +175,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
pos = bucket.next; pos = bucket.next;
} }
if pos == INVALID_POS { if pos == INVALID_POS {
return Err(FullError()); return Err(FullError);
} }
// Repair the freelist. // Repair the freelist.

View File

@@ -90,7 +90,6 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
self.map.dictionary[dict_pos as usize] = bucket.next; self.map.dictionary[dict_pos as usize] = bucket.next;
} }
PrevPos::Chained(bucket_pos) => { PrevPos::Chained(bucket_pos) => {
// println!("we think prev of {} is {bucket_pos}", self.bucket_pos);
self.map.buckets[bucket_pos as usize].next = bucket.next; self.map.buckets[bucket_pos as usize].next = bucket.next;
} }
_ => unreachable!(), _ => unreachable!(),
@@ -125,9 +124,6 @@ impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
/// Will return [`FullError`] if there are no unoccupied buckets in the map. /// Will return [`FullError`] if there are no unoccupied buckets in the map.
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> { pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?; let pos = self.map.alloc_bucket(self.key, value)?;
if pos == INVALID_POS {
return Err(FullError());
}
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize]; self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos; self.map.dictionary[self.dict_pos as usize] = pos;

View File

@@ -164,16 +164,16 @@ fn do_deletes(
fn do_shrink( fn do_shrink(
writer: &mut HashMapAccess<TestKey, usize>, writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>, shadow: &mut BTreeMap<TestKey, usize>,
from: u32,
to: u32, to: u32,
) { ) {
assert!(writer.shrink_goal().is_none()); assert!(writer.shrink_goal().is_none());
writer.begin_shrink(to); writer.begin_shrink(to);
assert_eq!(writer.shrink_goal(), Some(to as usize)); assert_eq!(writer.shrink_goal(), Some(to as usize));
while writer.get_num_buckets_in_use() > to as usize { for i in to..from {
let (k, _) = shadow.pop_first().unwrap(); if let Some(entry) = writer.entry_at_bucket(i as usize) {
let entry = writer.entry(k); shadow.remove(&entry._key);
if let Entry::Occupied(e) = entry { entry.remove();
e.remove();
} }
} }
let old_usage = writer.get_num_buckets_in_use(); let old_usage = writer.get_num_buckets_in_use();
@@ -298,7 +298,7 @@ fn test_shrink() {
let mut rng = rand::rng(); let mut rng = rand::rng();
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
do_shrink(&mut writer, &mut shadow, 1000); do_shrink(&mut writer, &mut shadow, 1500, 1000);
assert_eq!(writer.get_num_buckets(), 1000); assert_eq!(writer.get_num_buckets(), 1000);
do_deletes(500, &mut writer, &mut shadow); do_deletes(500, &mut writer, &mut shadow);
do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng); do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
@@ -315,7 +315,7 @@ fn test_shrink_grow_seq() {
do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng); do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 750"); eprintln!("Shrinking to 750");
do_shrink(&mut writer, &mut shadow, 750); do_shrink(&mut writer, &mut shadow, 1000, 750);
do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng); do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 1500"); eprintln!("Growing to 1500");
writer.grow(1500).unwrap(); writer.grow(1500).unwrap();
@@ -324,7 +324,7 @@ fn test_shrink_grow_seq() {
while shadow.len() > 100 { while shadow.len() > 100 {
do_deletes(1, &mut writer, &mut shadow); do_deletes(1, &mut writer, &mut shadow);
} }
do_shrink(&mut writer, &mut shadow, 200); do_shrink(&mut writer, &mut shadow, 1500, 200);
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng); do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 10k"); eprintln!("Growing to 10k");
writer.grow(10000).unwrap(); writer.grow(10000).unwrap();
@@ -349,8 +349,7 @@ fn test_bucket_ops() {
let pos = match writer.entry(1.into()) { let pos = match writer.entry(1.into()) {
Entry::Occupied(e) => { Entry::Occupied(e) => {
assert_eq!(e._key, 1.into()); assert_eq!(e._key, 1.into());
let pos = e.bucket_pos as usize; e.bucket_pos as usize
pos
} }
Entry::Vacant(_) => { Entry::Vacant(_) => {
panic!("Insert didn't affect entry"); panic!("Insert didn't affect entry");

View File

@@ -1,5 +1,3 @@
//! Shared memory utilities for neon communicator
pub mod hash; pub mod hash;
pub mod shmem; pub mod shmem;
pub mod sync; pub mod sync;

View File

@@ -6,7 +6,7 @@ use std::ptr::NonNull;
use nix::errno::Errno; use nix::errno::Errno;
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>; pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
pub(crate) type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>; pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>; pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>; pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>; pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
@@ -14,19 +14,34 @@ pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRw
/// Shared memory read-write lock. /// Shared memory read-write lock.
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>); pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
/// Simple macro that calls a function in the libc namespace and panics if the return value is nonzero.
macro_rules! libc_checked {
($fn_name:ident ( $($arg:expr),* )) => {{
let res = libc::$fn_name($($arg),*);
if res != 0 {
panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(res));
}
}};
}
impl PthreadRwLock { impl PthreadRwLock {
pub fn new(lock: *mut libc::pthread_rwlock_t) -> Self { /// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock.
///
/// # Safety
/// `lock` must be a valid, non-null pointer to a pthread rwlock. The pthread calls below panic if any of them returns an error.
pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
unsafe { unsafe {
let mut attrs = MaybeUninit::uninit(); let mut attrs = MaybeUninit::uninit();
// Ignoring return value here - only possible error is OOM. libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr()));
libc::pthread_rwlockattr_init(attrs.as_mut_ptr()); libc_checked!(pthread_rwlockattr_setpshared(
libc::pthread_rwlockattr_setpshared(attrs.as_mut_ptr(), libc::PTHREAD_PROCESS_SHARED); attrs.as_mut_ptr(),
// TODO(quantumish): worth making this function return Result? libc::PTHREAD_PROCESS_SHARED
libc::pthread_rwlock_init(lock, attrs.as_mut_ptr()); ));
libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr()));
// Safety: POSIX specifies that "any function affecting the attributes // Safety: POSIX specifies that "any function affecting the attributes
// object (including destruction) shall not affect any previously // object (including destruction) shall not affect any previously
// initialized read-write locks". // initialized read-write locks".
libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr()); libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr()));
Self(Some(NonNull::new_unchecked(lock))) Self(Some(NonNull::new_unchecked(lock)))
} }
} }
@@ -34,7 +49,7 @@ impl PthreadRwLock {
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> { fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
match self.0 { match self.0 {
None => { None => {
panic!("PthreadRwLock constructed badly - something likely used RawMutex::INIT") panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT")
} }
Some(x) => x, Some(x) => x,
} }
@@ -45,31 +60,16 @@ unsafe impl lock_api::RawRwLock for PthreadRwLock {
type GuardMarker = lock_api::GuardSend; type GuardMarker = lock_api::GuardSend;
const INIT: Self = Self(None); const INIT: Self = Self(None);
fn lock_shared(&self) {
unsafe {
let res = libc::pthread_rwlock_rdlock(self.inner().as_ptr());
if res != 0 {
panic!("rdlock failed with {}", Errno::from_raw(res));
}
}
}
fn try_lock_shared(&self) -> bool { fn try_lock_shared(&self) -> bool {
unsafe { unsafe {
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr()); let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
match res { match res {
0 => true, 0 => true,
libc::EAGAIN => false, libc::EAGAIN => false,
_ => panic!("try_rdlock failed with {}", Errno::from_raw(res)), _ => panic!(
} "pthread_rwlock_tryrdlock failed with {}",
} Errno::from_raw(res)
} ),
fn lock_exclusive(&self) {
unsafe {
let res = libc::pthread_rwlock_wrlock(self.inner().as_ptr());
if res != 0 {
panic!("wrlock failed with {}", Errno::from_raw(res));
} }
} }
} }
@@ -85,20 +85,27 @@ unsafe impl lock_api::RawRwLock for PthreadRwLock {
} }
} }
unsafe fn unlock_exclusive(&self) { fn lock_shared(&self) {
unsafe { unsafe {
let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr()));
if res != 0 {
panic!("unlock failed with {}", Errno::from_raw(res));
}
} }
} }
fn lock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_shared(&self) { unsafe fn unlock_shared(&self) {
unsafe { unsafe {
let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
if res != 0 {
panic!("unlock failed with {}", Errno::from_raw(res));
}
} }
} }
} }

View File

@@ -749,7 +749,18 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
trace!("got query {query_string:?}"); trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await { if let Err(e) = handler.process_query(self, query_string).await {
match e { match e {
QueryError::Shutdown => return Ok(ProcessMsgResult::Break), err @ QueryError::Shutdown => {
// Notify postgres of the connection shutdown at the libpq
// protocol level. This avoids postgres having to tell an idle
// connection apart from a stale one, which is bug-prone.
let shutdown_error = short_error(&err);
self.write_message_noflush(&BeMessage::ErrorResponse(
&shutdown_error,
Some(err.pg_error_code()),
))?;
return Ok(ProcessMsgResult::Break);
}
QueryError::SimulatedConnectionError => { QueryError::SimulatedConnectionError => {
return Err(QueryError::SimulatedConnectionError); return Err(QueryError::SimulatedConnectionError);
} }

View File

@@ -47,6 +47,7 @@ tracing-subscriber = { workspace = true, features = ["json", "registry"] }
tracing-utils.workspace = true tracing-utils.workspace = true
rand.workspace = true rand.workspace = true
scopeguard.workspace = true scopeguard.workspace = true
uuid.workspace = true
strum.workspace = true strum.workspace = true
strum_macros.workspace = true strum_macros.workspace = true
walkdir.workspace = true walkdir.workspace = true

View File

@@ -12,7 +12,8 @@ use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
}; };
use pem::Pem; use pem::Pem;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
use uuid::Uuid;
use crate::id::TenantId; use crate::id::TenantId;
@@ -25,6 +26,11 @@ pub enum Scope {
/// Provides access to all data for a specific tenant (specified in `struct Claims` below) /// Provides access to all data for a specific tenant (specified in `struct Claims` below)
// TODO: join these two? // TODO: join these two?
Tenant, Tenant,
/// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope
/// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped
/// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint
/// scope token to ensure that untrusted compute nodes can't fetch spec for arbitrary endpoints.
TenantEndpoint,
/// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs. /// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
/// Should only be used e.g. for status check/tenant creation/list. /// Should only be used e.g. for status check/tenant creation/list.
PageServerApi, PageServerApi,
@@ -51,17 +57,43 @@ pub enum Scope {
ControllerPeer, ControllerPeer,
} }
fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result<Option<Uuid>, D::Error>
where
D: Deserializer<'de>,
{
let opt = Option::<String>::deserialize(deserializer)?;
match opt.as_deref() {
Some("") => Ok(None),
Some(s) => Uuid::parse_str(s)
.map(Some)
.map_err(serde::de::Error::custom),
None => Ok(None),
}
}
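A sketch of the deserializer's behavior on control-plane input (the probe struct is illustrative; serde_json assumed available, e.g. as a dev-dependency):

#[derive(serde::Deserialize)]
struct ClaimsProbe {
    #[serde(default, deserialize_with = "deserialize_empty_string_as_none_uuid")]
    endpoint_id: Option<Uuid>,
}

let empty: ClaimsProbe = serde_json::from_str(r#"{"endpoint_id": ""}"#).unwrap();
assert!(empty.endpoint_id.is_none()); // "" is treated like a missing field

let set: ClaimsProbe = serde_json::from_str(
    r#"{"endpoint_id": "67e55044-10b1-426f-9247-bb680e5fe0c8"}"#,
).unwrap();
assert!(set.endpoint_id.is_some());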
/// JWT payload. See docs/authentication.md for the format /// JWT payload. See docs/authentication.md for the format
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Claims { pub struct Claims {
#[serde(default)] #[serde(default)]
pub tenant_id: Option<TenantId>, pub tenant_id: Option<TenantId>,
#[serde(
default,
skip_serializing_if = "Option::is_none",
// Neon control plane includes this field as empty in the claims.
// Consider it None in those cases.
deserialize_with = "deserialize_empty_string_as_none_uuid"
)]
pub endpoint_id: Option<Uuid>,
pub scope: Scope, pub scope: Scope,
} }
impl Claims { impl Claims {
pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self { pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
Self { tenant_id, scope } Self {
tenant_id,
scope,
endpoint_id: None,
}
} }
} }
@@ -212,6 +244,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let expected_claims = Claims { let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()), tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant, scope: Scope::Tenant,
endpoint_id: None,
}; };
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519: // A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
@@ -240,6 +273,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let claims = Claims { let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()), tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant, scope: Scope::Tenant,
endpoint_id: None,
}; };
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap(); let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();

View File

@@ -53,6 +53,10 @@ impl ShardCount {
pub const MAX: Self = Self(u8::MAX); pub const MAX: Self = Self(u8::MAX);
pub const MIN: Self = Self(0); pub const MIN: Self = Self(0);
pub fn unsharded() -> Self {
ShardCount(0)
}
/// The internal value of a ShardCount may be zero, which means "1 shard, but use /// The internal value of a ShardCount may be zero, which means "1 shard, but use
/// legacy format for TenantShardId that excludes the shard suffix", also known /// legacy format for TenantShardId that excludes the shard suffix", also known
/// as [`TenantShardId::unsharded`]. /// as [`TenantShardId::unsharded`].
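In other words, zero is a legacy encoding of "one shard", not "no shards"; a quick sketch, assuming the usual derives:

let n = ShardCount::unsharded();
assert_eq!(n, ShardCount(0));
assert_eq!(n, ShardCount::MIN);
// A TenantShardId built with this count is formatted without the shard suffix.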

View File

@@ -873,6 +873,22 @@ impl Client {
.map_err(Error::ReceiveBody) .map_err(Error::ReceiveBody)
} }
pub async fn reset_alert_gauges(&self) -> Result<()> {
let uri = format!(
"{}/hadron-internal/reset_alert_gauges",
self.mgmt_api_endpoint
);
self.start_request(Method::POST, uri)
.send()
.await
.map_err(Error::SendRequest)?
.error_from_body()
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn wait_lsn( pub async fn wait_lsn(
&self, &self,
tenant_shard_id: TenantShardId, tenant_shard_id: TenantShardId,

View File

@@ -20,7 +20,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
| Scope::GenerationsApi | Scope::GenerationsApi
| Scope::Infra | Scope::Infra
| Scope::Scrubber | Scope::Scrubber
| Scope::ControllerPeer, | Scope::ControllerPeer
| Scope::TenantEndpoint,
_, _,
) => Err(AuthError( ) => Err(AuthError(
format!( format!(

View File

@@ -817,6 +817,7 @@ impl Timeline {
let gc_cutoff_lsn_guard = self.get_applied_gc_cutoff_lsn(); let gc_cutoff_lsn_guard = self.get_applied_gc_cutoff_lsn();
let gc_cutoff_planned = { let gc_cutoff_planned = {
let gc_info = self.gc_info.read().unwrap(); let gc_info = self.gc_info.read().unwrap();
info!(cutoffs=?gc_info.cutoffs, applied_cutoff=%*gc_cutoff_lsn_guard, "starting find_lsn_for_timestamp");
gc_info.min_cutoff() gc_info.min_cutoff()
}; };
// Usually the planned cutoff is newer than the cutoff of the last gc run, // Usually the planned cutoff is newer than the cutoff of the last gc run,

View File

@@ -1,4 +1,3 @@
// Definitions of some core PostgreSQL datatypes. // Definitions of some core PostgreSQL datatypes.
/// XLogRecPtr is defined in "access/xlogdefs.h" as: /// XLogRecPtr is defined in "access/xlogdefs.h" as:

View File

@@ -1,10 +1,13 @@
//! Glue code to hook up Rust logging with the `tracing` crate to the PostgreSQL log //! Glue code to hook up Rust logging with the `tracing` crate to the PostgreSQL log
//! //!
//! In the Rust threads, the log messages are written to a mpsc Channel, and the Postgres //! In the Rust threads, the log messages are written to a mpsc Channel, and the Postgres
//! process latch is raised. That wakes up the loop in the main thread. It reads the //! process latch is raised. That wakes up the loop in the main thread, see
//! message from the channel and ereport()s it. This ensures that only one thread, the main //! `communicator_new_bgworker_main()`. It reads the message from the channel and
//! thread, calls the PostgreSQL logging routines at any time. //! ereport()s it. This ensures that only one thread, the main thread, calls the
//! PostgreSQL logging routines at any time.
use std::ffi::c_char;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc::sync_channel; use std::sync::mpsc::sync_channel;
use std::sync::mpsc::{Receiver, SyncSender}; use std::sync::mpsc::{Receiver, SyncSender};
use std::sync::mpsc::{TryRecvError, TrySendError}; use std::sync::mpsc::{TryRecvError, TrySendError};
@@ -12,27 +15,32 @@ use std::sync::mpsc::{TryRecvError, TrySendError};
use tracing::info; use tracing::info;
use tracing::{Event, Level, Metadata, Subscriber}; use tracing::{Event, Level, Metadata, Subscriber};
use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::fmt::FmtContext;
use tracing_subscriber::fmt::FormatEvent;
use tracing_subscriber::fmt::FormatFields;
use tracing_subscriber::fmt::FormattedFields;
use tracing_subscriber::fmt::MakeWriter;
use tracing_subscriber::fmt::format::Writer; use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields, FormattedFields, MakeWriter};
use tracing_subscriber::registry::LookupSpan; use tracing_subscriber::registry::LookupSpan;
use crate::worker_process::callbacks::callback_set_my_latch; use crate::worker_process::callbacks::callback_set_my_latch;
pub struct LoggingState { /// This handle is passed to the C code, and used by [`communicator_worker_poll_logging`]
pub struct LoggingReceiver {
receiver: Receiver<FormattedEventWithMeta>, receiver: Receiver<FormattedEventWithMeta>,
} }
/// This is passed to `tracing`
struct LoggingSender {
sender: SyncSender<FormattedEventWithMeta>,
}
static DROPPED_EVENT_COUNT: AtomicU64 = AtomicU64::new(0);
/// Called once, at worker process startup. The returned LoggingState is passed back /// Called once, at worker process startup. The returned LoggingState is passed back
/// in the subsequent calls to `pump_logging`. It is opaque to the C code. /// in the subsequent calls to `pump_logging`. It is opaque to the C code.
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub extern "C" fn configure_logging() -> Box<LoggingState> { pub extern "C" fn communicator_worker_configure_logging() -> Box<LoggingReceiver> {
let (sender, receiver) = sync_channel(1000); let (sender, receiver) = sync_channel(1000);
let maker = Maker { channel: sender }; let receiver = LoggingReceiver { receiver };
let sender = LoggingSender { sender };
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
let r = tracing_subscriber::registry(); let r = tracing_subscriber::registry();
@@ -41,32 +49,45 @@ pub extern "C" fn configure_logging() -> Box<LoggingState> {
tracing_subscriber::fmt::layer() tracing_subscriber::fmt::layer()
.with_ansi(false) .with_ansi(false)
.event_format(SimpleFormatter::new()) .event_format(SimpleFormatter::new())
.with_writer(maker) .with_writer(sender)
// TODO: derive this from log_min_messages? // TODO: derive this from log_min_messages? Currently the code in
// communicator_process.c forces log_min_messages='INFO'.
.with_filter(LevelFilter::from_level(Level::INFO)), .with_filter(LevelFilter::from_level(Level::INFO)),
); );
r.init(); r.init();
info!("communicator process logging started"); info!("communicator process logging started");
let state = LoggingState { receiver }; Box::new(receiver)
Box::new(state)
} }
/// Read one message from the logging queue. This is essentially a wrapper to Receiver, /// Read one message from the logging queue. This is essentially a wrapper to Receiver,
/// with a C-friendly signature. /// with a C-friendly signature.
/// ///
/// The message is copied into *errbuf, which is a caller-supplied buffer of size `errbuf_len`. /// The message is copied into *errbuf, which is a caller-supplied buffer of size
/// If the message doesn't fit in the buffer, it is truncated. It is always NULL-terminated. /// `errbuf_len`. If the message doesn't fit in the buffer, it is truncated. It is always
/// NULL-terminated.
/// ///
/// The error level is returned in *elevel_p. It's one of the PostgreSQL error levels, see elog.h /// The error level is returned in *elevel_p. It's one of the PostgreSQL error levels, see
/// elog.h
///
/// If there was a message, *dropped_event_count_p is also updated with a counter of how
/// many log messages in total have been dropped. By comparing that with the value from
/// the previous call, you can tell how many were dropped since the last call.
///
/// Returns:
///
/// 0 if there were no messages
/// 1 if there was a message. The message and its level are returned in
/// *errbuf and *elevel_p. *dropped_event_count_p is also updated.
/// -1 on error, i.e the other end of the queue was disconnected
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub extern "C" fn pump_logging( pub extern "C" fn communicator_worker_poll_logging(
state: &mut LoggingState, state: &mut LoggingReceiver,
errbuf: *mut u8, errbuf: *mut c_char,
errbuf_len: u32, errbuf_len: u32,
elevel_p: &mut i32, elevel_p: &mut i32,
dropped_event_count_p: &mut u64,
) -> i32 { ) -> i32 {
let msg = match state.receiver.try_recv() { let msg = match state.receiver.try_recv() {
Err(TryRecvError::Empty) => return 0, Err(TryRecvError::Empty) => return 0,
@@ -75,15 +96,17 @@ pub extern "C" fn pump_logging(
}; };
let src: &[u8] = &msg.message; let src: &[u8] = &msg.message;
let dst = errbuf; let dst: *mut u8 = errbuf.cast();
let len = std::cmp::min(src.len(), errbuf_len as usize - 1); let len = std::cmp::min(src.len(), errbuf_len as usize - 1);
unsafe { unsafe {
std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len); std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len);
*(errbuf.add(len)) = b'\0'; // NULL terminator *(dst.add(len)) = b'\0'; // NULL terminator
} }
// XXX: these levels are copied from PostgreSQL's elog.h. Introduce another enum // Map the tracing Level to PostgreSQL elevel.
// to hide these? //
// XXX: These levels are copied from PostgreSQL's elog.h. Introduce another enum to
// hide these?
*elevel_p = match msg.level { *elevel_p = match msg.level {
Level::TRACE => 10, // DEBUG5 Level::TRACE => 10, // DEBUG5
Level::DEBUG => 14, // DEBUG1 Level::DEBUG => 14, // DEBUG1
@@ -92,6 +115,8 @@ pub extern "C" fn pump_logging(
Level::ERROR => 21, // ERROR Level::ERROR => 21, // ERROR
}; };
*dropped_event_count_p = DROPPED_EVENT_COUNT.load(Ordering::Relaxed);
1 1
} }
@@ -115,7 +140,7 @@ impl Default for FormattedEventWithMeta {
struct EventBuilder<'a> { struct EventBuilder<'a> {
event: FormattedEventWithMeta, event: FormattedEventWithMeta,
maker: &'a Maker, sender: &'a LoggingSender,
} }
impl std::io::Write for EventBuilder<'_> { impl std::io::Write for EventBuilder<'_> {
@@ -123,25 +148,21 @@ impl std::io::Write for EventBuilder<'_> {
self.event.message.write(buf) self.event.message.write(buf)
} }
fn flush(&mut self) -> std::io::Result<()> { fn flush(&mut self) -> std::io::Result<()> {
self.maker.send_event(self.event.clone()); self.sender.send_event(self.event.clone());
Ok(()) Ok(())
} }
} }
impl Drop for EventBuilder<'_> { impl Drop for EventBuilder<'_> {
fn drop(&mut self) { fn drop(&mut self) {
let maker = self.maker; let sender = self.sender;
let event = std::mem::take(&mut self.event); let event = std::mem::take(&mut self.event);
maker.send_event(event); sender.send_event(event);
} }
} }
struct Maker { impl<'a> MakeWriter<'a> for LoggingSender {
channel: SyncSender<FormattedEventWithMeta>,
}
impl<'a> MakeWriter<'a> for Maker {
type Writer = EventBuilder<'a>; type Writer = EventBuilder<'a>;
fn make_writer(&'a self) -> Self::Writer { fn make_writer(&'a self) -> Self::Writer {
@@ -154,33 +175,38 @@ impl<'a> MakeWriter<'a> for Maker {
message: Vec::new(), message: Vec::new(),
level: *meta.level(), level: *meta.level(),
}, },
maker: self, sender: self,
} }
} }
} }
impl Maker { impl LoggingSender {
fn send_event(&self, e: FormattedEventWithMeta) { fn send_event(&self, e: FormattedEventWithMeta) {
match self.channel.try_send(e) { match self.sender.try_send(e) {
Ok(()) => { Ok(()) => {
// notify the main thread // notify the main thread
callback_set_my_latch(); callback_set_my_latch();
} }
Err(TrySendError::Disconnected(_)) => {} Err(TrySendError::Disconnected(_)) => {}
Err(TrySendError::Full(_)) => { Err(TrySendError::Full(_)) => {
// TODO: record that some messages were lost // The queue is full, cannot send any more. To avoid blocking the tokio
// thread, simply drop the message. Better to lose some logs than get
// stuck if there's a problem with the logging.
//
// Record the fact that a message was dropped by incrementing the
// counter.
DROPPED_EVENT_COUNT.fetch_add(1, Ordering::Relaxed);
} }
} }
} }
} }
/// Simple formatter implementation for tracing_subscriber, which prints the log /// Simple formatter implementation for tracing_subscriber, which prints the log spans and
/// spans and message part like the default formatter, but no timestamp or error /// message part like the default formatter, but no timestamp or error level. The error
/// level. The error level is captured separately by `FormattedEventWithMeta`, /// level is captured separately by `FormattedEventWithMeta`, and when the error is
/// and when the error is printed by the main thread, with PostgreSQL ereport(), /// printed by the main thread, with PostgreSQL ereport(), it gets a timestamp at that
/// it gets a timestamp at that point. (The timestamp printed will therefore lag /// point. (The timestamp printed will therefore lag behind the timestamp on the event
/// behind the timestamp on the event here, if the main thread doesn't process /// here, if the main thread doesn't process the log message promptly)
/// the log message promptly)
struct SimpleFormatter; struct SimpleFormatter;
impl<S, N> FormatEvent<S, N> for SimpleFormatter impl<S, N> FormatEvent<S, N> for SimpleFormatter
@@ -199,11 +225,10 @@ where
for span in scope.from_root() { for span in scope.from_root() {
write!(writer, "{}", span.name())?; write!(writer, "{}", span.name())?;
// `FormattedFields` is a formatted representation of the span's // `FormattedFields` is a formatted representation of the span's fields,
// fields, which is stored in its extensions by the `fmt` layer's // which is stored in its extensions by the `fmt` layer's `new_span`
// `new_span` method. The fields will have been formatted // method. The fields will have been formatted by the same field formatter
// by the same field formatter that's provided to the event // that's provided to the event formatter in the `FmtContext`.
// formatter in the `FmtContext`.
let ext = span.extensions(); let ext = span.extensions();
let fields = &ext let fields = &ext
.get::<FormattedFields<N>>() .get::<FormattedFields<N>>()
@@ -220,7 +245,7 @@ where
// Write fields on the event // Write fields on the event
ctx.field_format().format_fields(writer.by_ref(), event)?; ctx.field_format().format_fields(writer.by_ref(), event)?;
writeln!(writer) Ok(())
} }
} }

View File

@@ -30,6 +30,7 @@ pub extern "C" fn communicator_worker_process_launch(
file_cache_path: *const c_char, file_cache_path: *const c_char,
initial_file_cache_size: u64, initial_file_cache_size: u64,
) -> &'static CommunicatorWorkerProcessStruct<'static> { ) -> &'static CommunicatorWorkerProcessStruct<'static> {
tracing::warn!("starting threads in rust code");
// Convert the arguments into more convenient Rust types // Convert the arguments into more convenient Rust types
let tenant_id = unsafe { CStr::from_ptr(tenant_id) }.to_str().unwrap(); let tenant_id = unsafe { CStr::from_ptr(tenant_id) }.to_str().unwrap();
let timeline_id = unsafe { CStr::from_ptr(timeline_id) }.to_str().unwrap(); let timeline_id = unsafe { CStr::from_ptr(timeline_id) }.to_str().unwrap();

View File

@@ -142,6 +142,7 @@ static bool bounce_needed(void *buffer);
static void *bounce_buf(void); static void *bounce_buf(void);
static void *bounce_write_if_needed(void *buffer); static void *bounce_write_if_needed(void *buffer);
static void pump_logging(struct LoggingReceiver *logging);
PGDLLEXPORT void communicator_new_bgworker_main(Datum main_arg); PGDLLEXPORT void communicator_new_bgworker_main(Datum main_arg);
static void communicator_new_backend_exit(int code, Datum arg); static void communicator_new_backend_exit(int code, Datum arg);
@@ -184,6 +185,9 @@ pg_init_communicator_new(void)
{ {
BackgroundWorker bgw; BackgroundWorker bgw;
if (!neon_use_communicator_worker)
return;
if (pageserver_connstring[0] == '\0' && pageserver_grpc_urls[0] == '\0') if (pageserver_connstring[0] == '\0' && pageserver_grpc_urls[0] == '\0')
{ {
/* running with local storage */ /* running with local storage */
@@ -211,6 +215,9 @@ communicator_new_shmem_size(void)
size_t size = 0; size_t size = 0;
int num_request_slots; int num_request_slots;
if (!neon_use_communicator_worker)
return 0;
size += MAXALIGN( size += MAXALIGN(
offsetof(CommunicatorShmemData, backends) + offsetof(CommunicatorShmemData, backends) +
MaxProcs * sizeof(CommunicatorShmemPerBackendData) MaxProcs * sizeof(CommunicatorShmemPerBackendData)
@@ -225,13 +232,16 @@ communicator_new_shmem_size(void)
} }
void void
communicator_new_shmem_request(void) CommunicatorNewShmemRequest(void)
{ {
if (!neon_use_communicator_worker)
return;
RequestAddinShmemSpace(communicator_new_shmem_size()); RequestAddinShmemSpace(communicator_new_shmem_size());
} }
void void
communicator_new_shmem_startup(void) CommunicatorNewShmemInit(void)
{ {
bool found; bool found;
int pipefd[2]; int pipefd[2];
@@ -242,6 +252,9 @@ communicator_new_shmem_startup(void)
uint64 initial_file_cache_size; uint64 initial_file_cache_size;
uint64 max_file_cache_size; uint64 max_file_cache_size;
if (!neon_use_communicator_worker)
return;
rc = pipe(pipefd); rc = pipe(pipefd);
if (rc != 0) if (rc != 0)
ereport(ERROR, ereport(ERROR,
@@ -294,10 +307,8 @@ communicator_new_bgworker_main(Datum main_arg)
{ {
char **connstrings; char **connstrings;
ShardMap shard_map; ShardMap shard_map;
struct LoggingState *logging;
char errbuf[1000];
int elevel;
uint64 file_cache_size; uint64 file_cache_size;
struct LoggingReceiver *logging;
const struct CommunicatorWorkerProcessStruct *proc_handle; const struct CommunicatorWorkerProcessStruct *proc_handle;
/* /*
@@ -334,8 +345,21 @@ communicator_new_bgworker_main(Datum main_arg)
for (int i = 0; i < shard_map.num_shards; i++) for (int i = 0; i < shard_map.num_shards; i++)
connstrings[i] = shard_map.connstring[i]; connstrings[i] = shard_map.connstring[i];
logging = configure_logging(); /*
* By default, INFO messages are not printed to the log. We want
* `tracing::info!` messages emitted from the communicator to be printed,
* however, so increase the log level.
*
* XXX: This overrides any user-set value from the config file. That's not
* great, but on the other hand, there should be little reason for the user to
* control the verbosity of the communicator. It's not too verbose by
* default.
*/
SetConfigOption("log_min_messages", "INFO", PGC_SUSET, PGC_S_OVERRIDE);
logging = communicator_worker_configure_logging();
elog(LOG, "launching worker process threads");
proc_handle = communicator_worker_process_launch( proc_handle = communicator_worker_process_launch(
cis, cis,
neon_tenant, neon_tenant,
@@ -348,11 +372,27 @@ communicator_new_bgworker_main(Datum main_arg)
file_cache_size); file_cache_size);
pfree(connstrings); pfree(connstrings);
cis = NULL; cis = NULL;
if (proc_handle == NULL)
{
/*
* Something went wrong. Before exiting, forward any log messages that
* might've been generated during the failed launch.
*/
pump_logging(logging);
elog(PANIC, "failure launching threads");
}
elog(LOG, "communicator threads started"); elog(LOG, "communicator threads started");
for (;;) for (;;)
{ {
int32 rc; ResetLatch(MyLatch);
/*
* Forward any log messages from the Rust threads into the normal
* Postgres logging facility.
*/
pump_logging(logging);
CHECK_FOR_INTERRUPTS(); CHECK_FOR_INTERRUPTS();
@@ -383,37 +423,81 @@ communicator_new_bgworker_main(Datum main_arg)
pfree(connstrings); pfree(connstrings);
} }
for (;;)
{
rc = pump_logging(logging, (uint8 *) errbuf, sizeof(errbuf), &elevel);
if (rc == 0)
{
/* nothing to do */
break;
}
else if (rc == 1)
{
/* Because we don't want to exit on error */
if (elevel == ERROR)
elevel = LOG;
if (elevel == INFO)
elevel = LOG;
elog(elevel, "[COMMUNICATOR] %s", errbuf);
}
else if (rc == -1)
{
elog(ERROR, "logging channel was closed unexpectedly");
}
}
(void) WaitLatch(MyLatch, (void) WaitLatch(MyLatch,
WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH,
0, 0,
PG_WAIT_EXTENSION); PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
} }
} }
static void
pump_logging(struct LoggingReceiver *logging)
{
char errbuf[1000];
int elevel;
int32 rc;
static uint64_t last_dropped_event_count = 0;
uint64_t dropped_event_count;
uint64_t dropped_now;
for (;;)
{
rc = communicator_worker_poll_logging(logging,
errbuf,
sizeof(errbuf),
&elevel,
&dropped_event_count);
if (rc == 0)
{
/* nothing to do */
break;
}
else if (rc == 1)
{
/* Because we don't want to exit on error */
if (message_level_is_interesting(elevel))
{
/*
* Prevent interrupts while cleaning up.
*
* (Not sure if this is required, but all the error handlers
* in Postgres that are installed as sigsetjmp() targets do
* this, so let's follow the example)
*/
HOLD_INTERRUPTS();
errstart(elevel, TEXTDOMAIN);
errmsg_internal("[COMMUNICATOR] %s", errbuf);
EmitErrorReport();
FlushErrorState();
/* Now we can allow interrupts again */
RESUME_INTERRUPTS();
}
}
else if (rc == -1)
{
elog(ERROR, "logging channel was closed unexpectedly");
}
}
/*
* If the queue was full at any time since the last time we reported it,
* report how many messages were lost. We do this outside the loop, so
* that if the logging system is clogged, we don't exacerbate it by
* printing lots of warnings about dropped messages.
*/
dropped_now = dropped_event_count - last_dropped_event_count;
if (dropped_now != 0)
{
elog(WARNING, "%lu communicator log messages were dropped because the log buffer was full",
(unsigned long) dropped_now);
last_dropped_event_count = dropped_event_count;
}
}
/* /*
* Callbacks from the rust code, in the communicator process. * Callbacks from the rust code, in the communicator process.
* *

View File

@@ -20,8 +20,8 @@
/* initialization at postmaster startup */ /* initialization at postmaster startup */
extern void pg_init_communicator_new(void); extern void pg_init_communicator_new(void);
extern void communicator_new_shmem_request(void); extern void CommunicatorNewShmemRequest(void);
extern void communicator_new_shmem_startup(void); extern void CommunicatorNewShmemInit(void);
/* initialization at backend startup */ /* initialization at backend startup */
extern void communicator_new_init(void); extern void communicator_new_init(void);

View File

@@ -219,10 +219,6 @@ char *lfc_path;
static uint64 lfc_generation; static uint64 lfc_generation;
static FileCacheControl *lfc_ctl; static FileCacheControl *lfc_ctl;
static bool lfc_do_prewarm; static bool lfc_do_prewarm;
static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM>=150000
static shmem_request_hook_type prev_shmem_request_hook;
#endif
bool lfc_store_prefetch_result; bool lfc_store_prefetch_result;
bool lfc_prewarm_update_ws_estimation; bool lfc_prewarm_update_ws_estimation;
@@ -346,20 +342,17 @@ lfc_ensure_opened(void)
return true; return true;
} }
static void void
lfc_shmem_startup(void) LfcShmemInit(void)
{ {
bool found; bool found;
static HASHCTL info; static HASHCTL info;
Assert(!neon_use_communicator_worker); if (neon_use_communicator_worker)
return;
if (prev_shmem_startup_hook) if (lfc_max_size <= 0)
{ return;
prev_shmem_startup_hook();
}
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found); lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
if (!found) if (!found)
@@ -404,19 +397,16 @@ lfc_shmem_startup(void)
ConditionVariableInit(&lfc_ctl->cv[i]); ConditionVariableInit(&lfc_ctl->cv[i]);
} }
LWLockRelease(AddinShmemInitLock);
} }
static void void
lfc_shmem_request(void) LfcShmemRequest(void)
{ {
#if PG_VERSION_NUM>=150000 if (lfc_max_size > 0)
if (prev_shmem_request_hook) {
prev_shmem_request_hook(); RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
#endif RequestNamedLWLockTranche("lfc_lock", 1);
}
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
} }
static bool static bool
@@ -550,7 +540,6 @@ lfc_init(void)
if (!process_shared_preload_libraries_in_progress) if (!process_shared_preload_libraries_in_progress)
neon_log(ERROR, "Neon module should be loaded via shared_preload_libraries"); neon_log(ERROR, "Neon module should be loaded via shared_preload_libraries");
DefineCustomBoolVariable("neon.store_prefetch_result_in_lfc", DefineCustomBoolVariable("neon.store_prefetch_result_in_lfc",
"Immediately store received prefetch result in LFC", "Immediately store received prefetch result in LFC",
NULL, NULL,
@@ -648,21 +637,6 @@ lfc_init(void)
NULL, NULL,
NULL, NULL,
NULL); NULL);
if (lfc_max_size == 0)
return;
if (neon_use_communicator_worker)
return;
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = lfc_shmem_startup;
#if PG_VERSION_NUM>=150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = lfc_shmem_request;
#else
lfc_shmem_request();
#endif
} }
FileCacheState* FileCacheState*

View File

@@ -105,6 +105,11 @@ static int pageserver_response_disconnect_timeout = 150000;
* has changed since last access, and to detect and retry copying the value if * has changed since last access, and to detect and retry copying the value if
* the postmaster changes the value concurrently. (Postmaster doesn't have a * the postmaster changes the value concurrently. (Postmaster doesn't have a
* PGPROC entry and therefore cannot use LWLocks.) * PGPROC entry and therefore cannot use LWLocks.)
*
* stripe_size is now also part of ShardMap, although it is defined by a separate GUC.
* Postgres doesn't provide any mechanism to enforce dependencies between GUCs,
* so we have to rely on the order of GUC definitions in the config file:
* "neon.stripe_size" should be defined prior to "neon.pageserver_connstring".
*/ */
typedef struct typedef struct
{ {
@@ -113,10 +118,6 @@ typedef struct
ShardMap shard_map; ShardMap shard_map;
} PagestoreShmemState; } PagestoreShmemState;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif
static shmem_startup_hook_type prev_shmem_startup_hook;
static PagestoreShmemState *pagestore_shared; static PagestoreShmemState *pagestore_shared;
static uint64 pagestore_local_counter = 0; static uint64 pagestore_local_counter = 0;
@@ -231,7 +232,10 @@ parse_shard_map(const char *connstr, ShardMap *result)
p = sep + 1; p = sep + 1;
} }
if (result) if (result)
{
result->num_shards = nshards; result->num_shards = nshards;
result->stripe_size = neon_stripe_size;
}
return true; return true;
} }
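Given that ordering requirement, a config excerpt would look like this (values illustrative):

# neon.stripe_size must come first so parse_shard_map() captures its final value
neon.stripe_size = 2048
neon.pageserver_connstring = 'host=ps0 port=6400,host=ps1 port=6400'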
@@ -317,12 +321,13 @@ AssignShardMap(const char *newval)
* last call, terminates all existing connections to all pageservers. * last call, terminates all existing connections to all pageservers.
*/ */
static void static void
load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p) load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p, size_t* stripe_size_p)
{ {
uint64 begin_update_counter; uint64 begin_update_counter;
uint64 end_update_counter; uint64 end_update_counter;
ShardMap *shard_map = &pagestore_shared->shard_map; ShardMap *shard_map = &pagestore_shared->shard_map;
shardno_t num_shards; shardno_t num_shards;
size_t stripe_size;
/* /*
* Postmaster can update the shared memory values concurrently, in which * Postmaster can update the shared memory values concurrently, in which
@@ -337,6 +342,7 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
end_update_counter = pg_atomic_read_u64(&pagestore_shared->end_update_counter); end_update_counter = pg_atomic_read_u64(&pagestore_shared->end_update_counter);
num_shards = shard_map->num_shards; num_shards = shard_map->num_shards;
stripe_size = shard_map->stripe_size;
if (connstr_p && shard_no < MAX_SHARDS) if (connstr_p && shard_no < MAX_SHARDS)
strlcpy(connstr_p, shard_map->connstring[shard_no], MAX_PAGESERVER_CONNSTRING_SIZE); strlcpy(connstr_p, shard_map->connstring[shard_no], MAX_PAGESERVER_CONNSTRING_SIZE);
pg_memory_barrier(); pg_memory_barrier();
@@ -371,6 +377,8 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
if (num_shards_p) if (num_shards_p)
*num_shards_p = num_shards; *num_shards_p = num_shards;
if (stripe_size_p)
*stripe_size_p = stripe_size;
} }
#define MB (1024*1024) #define MB (1024*1024)
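The begin/end counter dance in load_shard_map() above is a seqlock-style optimistic read. A minimal self-contained sketch of the pattern, with invented names, assuming the writer increments begin_update_counter before and end_update_counter after every update:

#include "postgres.h"
#include "port/atomics.h"

typedef struct
{
	pg_atomic_uint64 begin_update_counter;
	pg_atomic_uint64 end_update_counter;
	int			value;			/* the shared data being copied out */
} SeqSlot;

static int
seqslot_read(SeqSlot *slot)
{
	uint64		begin;
	uint64		end;
	int			copy;

	do
	{
		begin = pg_atomic_read_u64(&slot->begin_update_counter);
		end = pg_atomic_read_u64(&slot->end_update_counter);
		copy = slot->value;		/* may be torn if a write races */
		pg_memory_barrier();
		/* retry unless no writer was active before, during, or after the copy */
	} while (begin != end ||
			 begin != pg_atomic_read_u64(&slot->begin_update_counter) ||
			 end != pg_atomic_read_u64(&slot->end_update_counter));

	return copy;
}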
@@ -379,9 +387,10 @@ shardno_t
get_shard_number(BufferTag *tag) get_shard_number(BufferTag *tag)
{ {
shardno_t n_shards; shardno_t n_shards;
size_t stripe_size;
uint32 hash; uint32 hash;
load_shard_map(0, NULL, &n_shards); load_shard_map(0, NULL, &n_shards, &stripe_size);
#if PG_MAJORVERSION_NUM < 16 #if PG_MAJORVERSION_NUM < 16
hash = murmurhash32(tag->rnode.relNode); hash = murmurhash32(tag->rnode.relNode);
@@ -434,7 +443,7 @@ pageserver_connect(shardno_t shard_no, int elevel)
* Note that connstr is used both during connection start, and when we * Note that connstr is used both during connection start, and when we
* log the successful connection. * log the successful connection.
*/ */
load_shard_map(shard_no, connstr, NULL); load_shard_map(shard_no, connstr, NULL, NULL);
switch (shard->state) switch (shard->state)
{ {
@@ -1306,18 +1315,12 @@ check_neon_id(char **newval, void **extra, GucSource source)
return **newval == '\0' || HexDecodeString(id, *newval, 16); return **newval == '\0' || HexDecodeString(id, *newval, 16);
} }
static Size
PagestoreShmemSize(void)
{
return add_size(sizeof(PagestoreShmemState), NeonPerfCountersShmemSize());
}
static bool void
PagestoreShmemInit(void) PagestoreShmemInit(void)
{ {
bool found; bool found;
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
pagestore_shared = ShmemInitStruct("libpagestore shared state", pagestore_shared = ShmemInitStruct("libpagestore shared state",
sizeof(PagestoreShmemState), sizeof(PagestoreShmemState),
&found); &found);
@@ -1328,44 +1331,12 @@ PagestoreShmemInit(void)
memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap)); memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap));
AssignPageserverConnstring(pageserver_connstring, NULL); AssignPageserverConnstring(pageserver_connstring, NULL);
} }
NeonPerfCountersShmemInit();
LWLockRelease(AddinShmemInitLock);
return found;
} }
static void void
pagestore_shmem_startup_hook(void) PagestoreShmemRequest(void)
{ {
if (prev_shmem_startup_hook) RequestAddinShmemSpace(sizeof(PagestoreShmemState));
prev_shmem_startup_hook();
PagestoreShmemInit();
}
static void
pagestore_shmem_request(void)
{
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
RequestAddinShmemSpace(PagestoreShmemSize());
}
static void
pagestore_prepare_shmem(void)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = pagestore_shmem_request;
#else
pagestore_shmem_request();
#endif
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = pagestore_shmem_startup_hook;
} }
/* /*
@@ -1374,8 +1345,6 @@ pagestore_prepare_shmem(void)
void void
pg_init_libpagestore(void) pg_init_libpagestore(void)
{ {
pagestore_prepare_shmem();
DefineCustomStringVariable("neon.pageserver_connstring", DefineCustomStringVariable("neon.pageserver_connstring",
"connection string to the page server", "connection string to the page server",
NULL, NULL,
@@ -1535,8 +1504,6 @@ pg_init_libpagestore(void)
0, 0,
NULL, NULL, NULL); NULL, NULL, NULL);
relsize_hash_init();
if (page_server != NULL) if (page_server != NULL)
neon_log(ERROR, "libpagestore already loaded"); neon_log(ERROR, "libpagestore already loaded");


@@ -23,6 +23,7 @@
#include "replication/walsender.h" #include "replication/walsender.h"
#include "storage/ipc.h" #include "storage/ipc.h"
#include "storage/proc.h" #include "storage/proc.h"
#include "storage/ipc.h"
#include "funcapi.h" #include "funcapi.h"
#include "access/htup_details.h" #include "access/htup_details.h"
#include "utils/builtins.h" #include "utils/builtins.h"
@@ -62,12 +63,13 @@ static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags);
static void neon_ExecutorEnd(QueryDesc *queryDesc); static void neon_ExecutorEnd(QueryDesc *queryDesc);
static shmem_startup_hook_type prev_shmem_startup_hook; static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM>=150000 static void neon_shmem_startup_hook(void);
static shmem_request_hook_type prev_shmem_request_hook; static void neon_shmem_request_hook(void);
#if PG_MAJORVERSION_NUM >= 15
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif #endif
static void neon_shmem_request(void);
static void neon_shmem_startup_hook(void);
#if PG_MAJORVERSION_NUM >= 17 #if PG_MAJORVERSION_NUM >= 17
uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE; uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE;
@@ -457,15 +459,6 @@ _PG_init(void)
load_file("$libdir/neon_rmgr", false); load_file("$libdir/neon_rmgr", false);
#endif #endif
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_shmem_startup_hook;
#if PG_VERSION_NUM>=150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = neon_shmem_request;
#else
neon_shmem_request();
#endif
DefineCustomBoolVariable( DefineCustomBoolVariable(
"neon.use_communicator_worker", "neon.use_communicator_worker",
"Uses the communicator worker implementation", "Uses the communicator worker implementation",
@@ -476,14 +469,45 @@ _PG_init(void)
0, 0,
NULL, NULL, NULL); NULL, NULL, NULL);
/*
* Initializing a pre-loaded Postgres extension happens in three stages:
*
* 1. _PG_init() is called early at postmaster startup. In this stage, no
* shared memory has been allocated yet. Core Postgres GUCs have been
* initialized from the config files, but notably, MaxBackends has not
* been calculated yet. In this stage, we must register any extension GUCs
* and can do other early initialization that doesn't depend on shared
* memory. In this stage we must also register "shmem request" and
* "shmem starutup" hooks, to be called in stages 2 and 3.
*
* 2. After MaxBackends has been calculated, the "shmem request" hooks
* are called. The hooks can reserve shared memory by calling
* RequestAddinShmemSpace() and RequestNamedLWLockTranche(). The "shmem
* request hooks" are a new mechanism in Postgres v15. In v14 and
* below, you had to make those requests in stage 1 already, which
* means they could not depend on MaxBackends. (See hack in
* NeonPerfCountersShmemRequest())
*
* 3. After some more runtime-computed GUCs that affect the amount of
* shared memory needed have been calculated, the "shmem startup" hooks
* are called. In this stage, we allocate any shared memory, LWLocks
* and other shared resources.
*
* Here, in the 'neon' extension, we register just one shmem request hook
* and one startup hook, which call into functions in all the subsystems
* that are part of the extension. On v14, the ShmemRequest functions are
* called in stage 1, and on v15 onwards they are called in stage 2.
*/
/* Stage 1: Define GUCs and other early initialization */
pg_init_libpagestore(); pg_init_libpagestore();
relsize_hash_init();
lfc_init(); lfc_init();
pg_init_walproposer(); pg_init_walproposer();
init_lwlsncache(); init_lwlsncache();
pg_init_communicator(); pg_init_communicator();
if (neon_use_communicator_worker) pg_init_communicator_new();
pg_init_communicator_new();
Custom_XLogReaderRoutines = NeonOnDemandXLogReaderRoutines; Custom_XLogReaderRoutines = NeonOnDemandXLogReaderRoutines;
@@ -582,6 +606,22 @@ _PG_init(void)
ReportSearchPath(); ReportSearchPath();
/*
* Register initialization hooks for stage 2. (On v14, there are no "shmem
* request" hooks, so call the ShmemRequest functions immediately.)
*/
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = neon_shmem_request_hook;
#else
neon_shmem_request_hook();
#endif
/* Register hooks for stage 3 */
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_shmem_startup_hook;
/* Other misc initialization */
prev_ExecutorStart = ExecutorStart_hook; prev_ExecutorStart = ExecutorStart_hook;
ExecutorStart_hook = neon_ExecutorStart; ExecutorStart_hook = neon_ExecutorStart;
prev_ExecutorEnd = ExecutorEnd_hook; prev_ExecutorEnd = ExecutorEnd_hook;
@@ -673,17 +713,35 @@ approximate_working_set_size(PG_FUNCTION_ARGS)
PG_RETURN_INT32(dc); PG_RETURN_INT32(dc);
} }
/*
* Initialization stage 2: make requests for the amount of shared memory we
* will need.
*
* For a high-level explanation of the initialization process, see _PG_init().
*/
static void static void
neon_shmem_request(void) neon_shmem_request_hook(void)
{ {
#if PG_VERSION_NUM>=150000 #if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook) if (prev_shmem_request_hook)
prev_shmem_request_hook(); prev_shmem_request_hook();
#endif #endif
communicator_new_shmem_request(); LfcShmemRequest();
NeonPerfCountersShmemRequest();
PagestoreShmemRequest();
RelsizeCacheShmemRequest();
WalproposerShmemRequest();
LwLsnCacheShmemRequest();
CommunicatorNewShmemRequest();
} }
/*
* Initialization stage 3: Initialize shared memory.
*
* For a high-level explanation of the initialization process, see _PG_init().
*/
static void static void
neon_shmem_startup_hook(void) neon_shmem_startup_hook(void)
{ {
@@ -691,6 +749,16 @@ neon_shmem_startup_hook(void)
if (prev_shmem_startup_hook) if (prev_shmem_startup_hook)
prev_shmem_startup_hook(); prev_shmem_startup_hook();
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
LfcShmemInit();
NeonPerfCountersShmemInit();
PagestoreShmemInit();
RelsizeCacheShmemInit();
WalproposerShmemInit();
LwLsnCacheShmemInit();
CommunicatorNewShmemInit();
#if PG_MAJORVERSION_NUM >= 17 #if PG_MAJORVERSION_NUM >= 17
WAIT_EVENT_NEON_LFC_MAINTENANCE = WaitEventExtensionNew("Neon/FileCache_Maintenance"); WAIT_EVENT_NEON_LFC_MAINTENANCE = WaitEventExtensionNew("Neon/FileCache_Maintenance");
WAIT_EVENT_NEON_LFC_READ = WaitEventExtensionNew("Neon/FileCache_Read"); WAIT_EVENT_NEON_LFC_READ = WaitEventExtensionNew("Neon/FileCache_Read");
@@ -704,7 +772,7 @@ neon_shmem_startup_hook(void)
WAIT_EVENT_NEON_WAL_DL = WaitEventExtensionNew("Neon/WAL_Download"); WAIT_EVENT_NEON_WAL_DL = WaitEventExtensionNew("Neon/WAL_Download");
#endif #endif
communicator_new_shmem_startup(); LWLockRelease(AddinShmemInitLock);
} }
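To make the staged pattern described in _PG_init() concrete, here is a minimal sketch of a hypothetical extension wired the same way. All demo_* names are invented for illustration; the hook and shmem APIs are standard Postgres:

#include "postgres.h"

#include "fmgr.h"
#include "storage/ipc.h"
#include "storage/lwlock.h"
#include "storage/shmem.h"

PG_MODULE_MAGIC;

static shmem_startup_hook_type prev_shmem_startup_hook = NULL;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif
static void *demo_shared;

/* Stage 2: reserve space (on v15+, MaxBackends is valid here) */
static void
demo_shmem_request(void)
{
#if PG_VERSION_NUM >= 150000
	if (prev_shmem_request_hook)
		prev_shmem_request_hook();
#endif
	RequestAddinShmemSpace(1024);
}

/* Stage 3: allocate (or attach to) the shared memory */
static void
demo_shmem_startup(void)
{
	bool		found;

	if (prev_shmem_startup_hook)
		prev_shmem_startup_hook();

	LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
	demo_shared = ShmemInitStruct("demo", 1024, &found);
	LWLockRelease(AddinShmemInitLock);
}

/* Stage 1: register GUCs (omitted here) and the stage 2/3 hooks */
void
_PG_init(void)
{
#if PG_VERSION_NUM >= 150000
	prev_shmem_request_hook = shmem_request_hook;
	shmem_request_hook = demo_shmem_request;
#else
	demo_shmem_request();		/* v14: no request hook; request immediately */
#endif
	prev_shmem_startup_hook = shmem_startup_hook;
	shmem_startup_hook = demo_shmem_startup;
}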
/* /*


@@ -70,4 +70,19 @@ extern PGDLLEXPORT void WalProposerSync(int argc, char *argv[]);
extern PGDLLEXPORT void WalProposerMain(Datum main_arg); extern PGDLLEXPORT void WalProposerMain(Datum main_arg);
extern PGDLLEXPORT void LogicalSlotsMonitorMain(Datum main_arg); extern PGDLLEXPORT void LogicalSlotsMonitorMain(Datum main_arg);
extern void LfcShmemRequest(void);
extern void PagestoreShmemRequest(void);
extern void RelsizeCacheShmemRequest(void);
extern void WalproposerShmemRequest(void);
extern void LwLsnCacheShmemRequest(void);
extern void NeonPerfCountersShmemRequest(void);
extern void LfcShmemInit(void);
extern void PagestoreShmemInit(void);
extern void RelsizeCacheShmemInit(void);
extern void WalproposerShmemInit(void);
extern void LwLsnCacheShmemInit(void);
extern void NeonPerfCountersShmemInit(void);
#endif /* NEON_H */ #endif /* NEON_H */


@@ -1,5 +1,6 @@
#include "postgres.h" #include "postgres.h"
#include "neon.h"
#include "neon_lwlsncache.h" #include "neon_lwlsncache.h"
#include "miscadmin.h" #include "miscadmin.h"
@@ -81,14 +82,6 @@ static set_max_lwlsn_hook_type prev_set_max_lwlsn_hook = NULL;
static set_lwlsn_relation_hook_type prev_set_lwlsn_relation_hook = NULL; static set_lwlsn_relation_hook_type prev_set_lwlsn_relation_hook = NULL;
static set_lwlsn_db_hook_type prev_set_lwlsn_db_hook = NULL; static set_lwlsn_db_hook_type prev_set_lwlsn_db_hook = NULL;
static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook;
#endif
static void shmemrequest(void);
static void shmeminit(void);
static void neon_set_max_lwlsn(XLogRecPtr lsn); static void neon_set_max_lwlsn(XLogRecPtr lsn);
void void
@@ -99,16 +92,6 @@ init_lwlsncache(void)
lwlc_register_gucs(); lwlc_register_gucs();
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = shmeminit;
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = shmemrequest;
#else
shmemrequest();
#endif
prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook; prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook;
set_lwlsn_block_range_hook = neon_set_lwlsn_block_range; set_lwlsn_block_range_hook = neon_set_lwlsn_block_range;
prev_set_lwlsn_block_v_hook = set_lwlsn_block_v_hook; prev_set_lwlsn_block_v_hook = set_lwlsn_block_v_hook;
@@ -124,20 +107,19 @@ init_lwlsncache(void)
} }
static void shmemrequest(void) { void
LwLsnCacheShmemRequest(void)
{
Size requested_size = sizeof(LwLsnCacheCtl); Size requested_size = sizeof(LwLsnCacheCtl);
requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry)); requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry));
RequestAddinShmemSpace(requested_size); RequestAddinShmemSpace(requested_size);
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
} }
static void shmeminit(void) { void
LwLsnCacheShmemInit(void)
{
static HASHCTL info; static HASHCTL info;
bool found; bool found;
if (lwlsn_cache_size > 0) if (lwlsn_cache_size > 0)
@@ -157,9 +139,6 @@ static void shmeminit(void) {
} }
dlist_init(&LwLsnCache->lastWrittenLsnLRU); dlist_init(&LwLsnCache->lastWrittenLsnLRU);
LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr(); LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr();
if (prev_shmem_startup_hook) {
prev_shmem_startup_hook();
}
} }
/* /*


@@ -17,22 +17,32 @@
#include "storage/shmem.h" #include "storage/shmem.h"
#include "utils/builtins.h" #include "utils/builtins.h"
#include "neon.h"
#include "neon_perf_counters.h" #include "neon_perf_counters.h"
#include "neon_pgversioncompat.h" #include "neon_pgversioncompat.h"
neon_per_backend_counters *neon_per_backend_counters_shared; neon_per_backend_counters *neon_per_backend_counters_shared;
Size void
NeonPerfCountersShmemSize(void) NeonPerfCountersShmemRequest(void)
{ {
Size size = 0; Size size;
#if PG_MAJORVERSION_NUM < 15
size = add_size(size, mul_size(NUM_NEON_PERF_COUNTER_SLOTS, /* Hack: in PG14, MaxBackends is not yet initialized when NeonPerfCountersShmemRequest() is called.
sizeof(neon_per_backend_counters))); * Initialize it ourselves and then undo that, to prevent an assertion failure.
*/
return size; Assert(MaxBackends == 0); /* not initialized yet */
InitializeMaxBackends();
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
MaxBackends = 0;
#else
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
#endif
RequestAddinShmemSpace(size);
} }
void void
NeonPerfCountersShmemInit(void) NeonPerfCountersShmemInit(void)
{ {


@@ -250,6 +250,7 @@ typedef struct
{ {
char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE]; char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE];
size_t num_shards; size_t num_shards;
size_t stripe_size;
} ShardMap; } ShardMap;
extern bool parse_shard_map(const char *connstr, ShardMap *result); extern bool parse_shard_map(const char *connstr, ShardMap *result);


@@ -48,32 +48,23 @@ typedef struct
* algorithm */ * algorithm */
} RelSizeHashControl; } RelSizeHashControl;
static HTAB *relsize_hash;
static LWLockId relsize_lock;
static int relsize_hash_size;
static RelSizeHashControl* relsize_ctl;
static shmem_startup_hook_type prev_shmem_startup_hook = NULL;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void relsize_shmem_request(void);
#endif
/* /*
* Size of a cache entry is 36 bytes. So this default will take about 2.3 MB, * Size of a cache entry is 36 bytes. So this default will take about 2.3 MB,
* which seems reasonable. * which seems reasonable.
*/ */
#define DEFAULT_RELSIZE_HASH_SIZE (64 * 1024) #define DEFAULT_RELSIZE_HASH_SIZE (64 * 1024)
static void static HTAB *relsize_hash;
neon_smgr_shmem_startup(void) static LWLockId relsize_lock;
static int relsize_hash_size = DEFAULT_RELSIZE_HASH_SIZE;
static RelSizeHashControl* relsize_ctl;
void
RelsizeCacheShmemInit(void)
{ {
static HASHCTL info; static HASHCTL info;
bool found; bool found;
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
relsize_ctl = (RelSizeHashControl *) ShmemInitStruct("relsize_hash", sizeof(RelSizeHashControl), &found); relsize_ctl = (RelSizeHashControl *) ShmemInitStruct("relsize_hash", sizeof(RelSizeHashControl), &found);
if (!found) if (!found)
{ {
@@ -84,7 +75,6 @@ neon_smgr_shmem_startup(void)
relsize_hash_size, relsize_hash_size, relsize_hash_size, relsize_hash_size,
&info, &info,
HASH_ELEM | HASH_BLOBS); HASH_ELEM | HASH_BLOBS);
LWLockRelease(AddinShmemInitLock);
relsize_ctl->size = 0; relsize_ctl->size = 0;
relsize_ctl->hits = 0; relsize_ctl->hits = 0;
relsize_ctl->misses = 0; relsize_ctl->misses = 0;
@@ -249,34 +239,15 @@ relsize_hash_init(void)
PGC_POSTMASTER, PGC_POSTMASTER,
0, 0,
NULL, NULL, NULL); NULL, NULL, NULL);
if (relsize_hash_size > 0)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = relsize_shmem_request;
#else
RequestAddinShmemSpace(hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
RequestNamedLWLockTranche("neon_relsize", 1);
#endif
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_smgr_shmem_startup;
}
} }
#if PG_VERSION_NUM >= 150000
/* /*
* shmem_request hook: request additional shared resources. We'll allocate or * shmem_request hook: request additional shared resources. We'll allocate or
* attach to the shared resources in neon_smgr_shmem_startup(). * attach to the shared resources in RelsizeCacheShmemInit().
*/ */
static void void
relsize_shmem_request(void) RelsizeCacheShmemRequest(void)
{ {
if (prev_shmem_request_hook)
prev_shmem_request_hook();
RequestAddinShmemSpace(sizeof(RelSizeHashControl) + hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry))); RequestAddinShmemSpace(sizeof(RelSizeHashControl) + hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
RequestNamedLWLockTranche("neon_relsize", 1); RequestNamedLWLockTranche("neon_relsize", 1);
} }
#endif


@@ -83,10 +83,8 @@ static XLogRecPtr standby_flush_lsn = InvalidXLogRecPtr;
static XLogRecPtr standby_apply_lsn = InvalidXLogRecPtr; static XLogRecPtr standby_apply_lsn = InvalidXLogRecPtr;
static HotStandbyFeedback agg_hs_feedback; static HotStandbyFeedback agg_hs_feedback;
static void nwp_shmem_startup_hook(void);
static void nwp_register_gucs(void); static void nwp_register_gucs(void);
static void assign_neon_safekeepers(const char *newval, void *extra); static void assign_neon_safekeepers(const char *newval, void *extra);
static void nwp_prepare_shmem(void);
static uint64 backpressure_lag_impl(void); static uint64 backpressure_lag_impl(void);
static uint64 startup_backpressure_wrap(void); static uint64 startup_backpressure_wrap(void);
static bool backpressure_throttling_impl(void); static bool backpressure_throttling_impl(void);
@@ -99,11 +97,6 @@ static TimestampTz walprop_pg_get_current_timestamp(WalProposer *wp);
static void walprop_pg_load_libpqwalreceiver(void); static void walprop_pg_load_libpqwalreceiver(void);
static process_interrupts_callback_t PrevProcessInterruptsCallback = NULL; static process_interrupts_callback_t PrevProcessInterruptsCallback = NULL;
static shmem_startup_hook_type prev_shmem_startup_hook_type;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void walproposer_shmem_request(void);
#endif
static void WalproposerShmemInit_SyncSafekeeper(void); static void WalproposerShmemInit_SyncSafekeeper(void);
@@ -193,8 +186,6 @@ pg_init_walproposer(void)
nwp_register_gucs(); nwp_register_gucs();
nwp_prepare_shmem();
delay_backend_us = &startup_backpressure_wrap; delay_backend_us = &startup_backpressure_wrap;
PrevProcessInterruptsCallback = ProcessInterruptsCallback; PrevProcessInterruptsCallback = ProcessInterruptsCallback;
ProcessInterruptsCallback = backpressure_throttling_impl; ProcessInterruptsCallback = backpressure_throttling_impl;
@@ -494,12 +485,11 @@ WalproposerShmemSize(void)
return sizeof(WalproposerShmemState); return sizeof(WalproposerShmemState);
} }
static bool void
WalproposerShmemInit(void) WalproposerShmemInit(void)
{ {
bool found; bool found;
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
walprop_shared = ShmemInitStruct("Walproposer shared state", walprop_shared = ShmemInitStruct("Walproposer shared state",
sizeof(WalproposerShmemState), sizeof(WalproposerShmemState),
&found); &found);
@@ -517,9 +507,6 @@ WalproposerShmemInit(void)
pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0); pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0);
/* END_HADRON */ /* END_HADRON */
} }
LWLockRelease(AddinShmemInitLock);
return found;
} }
static void static void
@@ -623,42 +610,15 @@ walprop_register_bgworker(void)
/* shmem handling */ /* shmem handling */
static void
nwp_prepare_shmem(void)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = walproposer_shmem_request;
#else
RequestAddinShmemSpace(WalproposerShmemSize());
#endif
prev_shmem_startup_hook_type = shmem_startup_hook;
shmem_startup_hook = nwp_shmem_startup_hook;
}
#if PG_VERSION_NUM >= 150000
/* /*
* shmem_request hook: request additional shared resources. We'll allocate or * shmem_request hook: request additional shared resources. We'll allocate or
* attach to the shared resources in nwp_shmem_startup_hook(). * attach to the shared resources in WalproposerShmemInit().
*/ */
static void void
walproposer_shmem_request(void) WalproposerShmemRequest(void)
{ {
if (prev_shmem_request_hook)
prev_shmem_request_hook();
RequestAddinShmemSpace(WalproposerShmemSize()); RequestAddinShmemSpace(WalproposerShmemSize());
} }
#endif
static void
nwp_shmem_startup_hook(void)
{
if (prev_shmem_startup_hook_type)
prev_shmem_startup_hook_type();
WalproposerShmemInit();
}
WalproposerShmemState * WalproposerShmemState *
GetWalpropShmemState(void) GetWalpropShmemState(void)


@@ -10,6 +10,7 @@ use tokio::time::Instant;
use tracing::{debug, info}; use tracing::{debug, info};
use crate::config::ProjectInfoCacheOptions; use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName}; use crate::types::{EndpointId, RoleName};
@@ -36,22 +37,37 @@ impl<T> Entry<T> {
} }
pub(crate) fn get(&self) -> Option<&T> { pub(crate) fn get(&self) -> Option<&T> {
(self.expires_at > Instant::now()).then_some(&self.value) (!self.is_expired()).then_some(&self.value)
}
fn is_expired(&self) -> bool {
self.expires_at <= Instant::now()
} }
} }
struct EndpointInfo { struct EndpointInfo {
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>, role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
controls: Option<Entry<EndpointAccessControl>>, controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
} }
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
impl EndpointInfo { impl EndpointInfo {
pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> { pub(crate) fn get_role_secret_with_ttl(
self.role_controls.get(&role_name)?.get().cloned() &self,
role_name: RoleNameInt,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let entry = self.role_controls.get(&role_name)?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
} }
pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> { pub(crate) fn get_controls_with_ttl(
self.controls.as_ref()?.get().cloned() &self,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let entry = self.controls.as_ref()?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
} }
pub(crate) fn invalidate_endpoint(&mut self) { pub(crate) fn invalidate_endpoint(&mut self) {
@@ -153,28 +169,28 @@ impl ProjectInfoCacheImpl {
self.cache.get(&endpoint_id) self.cache.get(&endpoint_id)
} }
pub(crate) fn get_role_secret( pub(crate) fn get_role_secret_with_ttl(
&self, &self,
endpoint_id: &EndpointId, endpoint_id: &EndpointId,
role_name: &RoleName, role_name: &RoleName,
) -> Option<RoleAccessControl> { ) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let role_name = RoleNameInt::get(role_name)?; let role_name = RoleNameInt::get(role_name)?;
let endpoint_info = self.get_endpoint_cache(endpoint_id)?; let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret(role_name) endpoint_info.get_role_secret_with_ttl(role_name)
} }
pub(crate) fn get_endpoint_access( pub(crate) fn get_endpoint_access_with_ttl(
&self, &self,
endpoint_id: &EndpointId, endpoint_id: &EndpointId,
) -> Option<EndpointAccessControl> { ) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let endpoint_info = self.get_endpoint_cache(endpoint_id)?; let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls() endpoint_info.get_controls_with_ttl()
} }
pub(crate) fn insert_endpoint_access( pub(crate) fn insert_endpoint_access(
&self, &self,
account_id: Option<AccountIdInt>, account_id: Option<AccountIdInt>,
project_id: ProjectIdInt, project_id: Option<ProjectIdInt>,
endpoint_id: EndpointIdInt, endpoint_id: EndpointIdInt,
role_name: RoleNameInt, role_name: RoleNameInt,
controls: EndpointAccessControl, controls: EndpointAccessControl,
@@ -183,26 +199,89 @@ impl ProjectInfoCacheImpl {
if let Some(account_id) = account_id { if let Some(account_id) = account_id {
self.insert_account2endpoint(account_id, endpoint_id); self.insert_account2endpoint(account_id, endpoint_id);
} }
self.insert_project2endpoint(project_id, endpoint_id); if let Some(project_id) = project_id {
self.insert_project2endpoint(project_id, endpoint_id);
}
if self.cache.len() >= self.config.size { if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle. // If there are too many entries, wait until the next gc cycle.
return; return;
} }
let controls = Entry::new(controls, self.config.ttl); debug!(
let role_controls = Entry::new(role_controls, self.config.ttl); key = &*endpoint_id,
"created a cache entry for endpoint access"
);
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
match self.cache.entry(endpoint_id) { match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => { clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo { e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]), role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls: Some(controls), controls,
}); });
} }
clashmap::Entry::Occupied(mut e) => { clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut(); let ep = e.get_mut();
ep.controls = Some(controls); ep.controls = controls;
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
}
pub(crate) fn insert_endpoint_access_err(
&self,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
msg: Box<ControlPlaneErrorMessage>,
ttl: Option<Duration>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for an endpoint access error"
);
let ttl = ttl.unwrap_or(self.config.ttl);
let controls = if msg.get_reason() == Reason::RoleProtected {
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
None
} else {
// We can cache all the other errors in EndpointInfo.controls,
// because they don't depend on what role name we pass to control plane.
Some(Entry::new(Err(msg.clone()), ttl))
};
let role_controls = Entry::new(Err(msg), ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
if let Some(entry) = &ep.controls
&& !entry.is_expired()
&& entry.value.is_ok()
{
// If we have cached non-expired, non-error controls, keep them.
} else {
ep.controls = controls;
}
if ep.role_controls.len() < self.config.max_roles { if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls); ep.role_controls.insert(role_name, role_controls);
} }
@@ -245,7 +324,7 @@ impl ProjectInfoCacheImpl {
return; return;
}; };
if role_controls.get().expires_at <= Instant::now() { if role_controls.get().is_expired() {
role_controls.remove(); role_controls.remove();
} }
} }
@@ -284,13 +363,11 @@ impl ProjectInfoCacheImpl {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use super::*; use super::*;
use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret; use crate::scram::ServerSecret;
use crate::types::ProjectId; use std::sync::Arc;
#[tokio::test] #[tokio::test]
async fn test_project_info_cache_settings() { async fn test_project_info_cache_settings() {
@@ -301,9 +378,9 @@ mod tests {
ttl: Duration::from_secs(1), ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600), gc_interval: Duration::from_secs(600),
}); });
let project_id: ProjectId = "project".into(); let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
let endpoint_id: EndpointId = "endpoint".into(); let endpoint_id: EndpointId = "endpoint".into();
let account_id: Option<AccountIdInt> = None; let account_id = None;
let user1: RoleName = "user1".into(); let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into(); let user2: RoleName = "user2".into();
@@ -316,7 +393,7 @@ mod tests {
cache.insert_endpoint_access( cache.insert_endpoint_access(
account_id, account_id,
(&project_id).into(), project_id,
(&endpoint_id).into(), (&endpoint_id).into(),
(&user1).into(), (&user1).into(),
EndpointAccessControl { EndpointAccessControl {
@@ -332,7 +409,7 @@ mod tests {
cache.insert_endpoint_access( cache.insert_endpoint_access(
account_id, account_id,
(&project_id).into(), project_id,
(&endpoint_id).into(), (&endpoint_id).into(),
(&user2).into(), (&user2).into(),
EndpointAccessControl { EndpointAccessControl {
@@ -346,11 +423,17 @@ mod tests {
}, },
); );
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); let (cached, ttl) = cache
assert_eq!(cached.secret, secret1); .get_role_secret_with_ttl(&endpoint_id, &user1)
.unwrap();
assert_eq!(cached.unwrap().secret, secret1);
assert_eq!(ttl, cache.config.ttl);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); let (cached, ttl) = cache
assert_eq!(cached.secret, secret2); .get_role_secret_with_ttl(&endpoint_id, &user2)
.unwrap();
assert_eq!(cached.unwrap().secret, secret2);
assert_eq!(ttl, cache.config.ttl);
// Shouldn't add more than 2 roles. // Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into(); let user3: RoleName = "user3".into();
@@ -358,7 +441,7 @@ mod tests {
cache.insert_endpoint_access( cache.insert_endpoint_access(
account_id, account_id,
(&project_id).into(), project_id,
(&endpoint_id).into(), (&endpoint_id).into(),
(&user3).into(), (&user3).into(),
EndpointAccessControl { EndpointAccessControl {
@@ -372,17 +455,144 @@ mod tests {
}, },
); );
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user3)
.is_none()
);
let cached = cache.get_endpoint_access(&endpoint_id).unwrap(); let cached = cache
.get_endpoint_access_with_ttl(&endpoint_id)
.unwrap()
.0
.unwrap();
assert_eq!(cached.allowed_ips, allowed_ips); assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::advance(Duration::from_secs(2)).await; tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret(&endpoint_id, &user1); let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
assert!(cached.is_none()); assert!(cached.is_none());
let cached = cache.get_role_secret(&endpoint_id, &user2); let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
assert!(cached.is_none()); assert!(cached.is_none());
let cached = cache.get_endpoint_access(&endpoint_id); let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
assert!(cached.is_none()); assert!(cached.is_none());
} }
#[tokio::test]
async fn test_caching_project_info_errors() {
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 10,
max_roles: 10,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
});
let project_id = Some(ProjectIdInt::from(&"project".into()));
let endpoint_id: EndpointId = "endpoint".into();
let account_id = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let role_msg = Box::new(ControlPlaneErrorMessage {
error: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: Some(Status {
code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
message: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
details: Details {
error_info: Some(ErrorInfo {
reason: Reason::RoleProtected,
}),
retry_info: None,
user_facing_message: None,
},
}),
});
let generic_msg = Box::new(ControlPlaneErrorMessage {
error: "oh noes".to_owned().into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: None,
});
let get_role_secret = |endpoint_id, role_name| {
cache
.get_role_secret_with_ttl(endpoint_id, role_name)
.unwrap()
.0
};
let get_endpoint_access =
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
// stores role-specific errors only for get_role_secret
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
role_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
role_msg.error
);
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
// stores non-role-specific errors for both get_role_secret and get_endpoint_access
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
generic_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
generic_msg.error
);
assert_eq!(
get_endpoint_access(&endpoint_id).unwrap_err().error,
generic_msg.error
);
// error isn't returned for other roles in the same endpoint
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.is_none()
);
// success for a role does not overwrite errors for other roles
cache.insert_endpoint_access(
account_id,
project_id,
(&endpoint_id).into(),
(&user2).into(),
EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret.clone(),
},
);
assert!(get_role_secret(&endpoint_id, &user1).is_err());
assert!(get_role_secret(&endpoint_id, &user2).is_ok());
// ...but does clear the access control error
assert!(get_endpoint_access(&endpoint_id).is_ok());
// storing an error does not overwrite a successful access control response
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user2).into(),
generic_msg.clone(),
None,
);
assert!(get_role_secret(&endpoint_id, &user2).is_err());
assert!(get_endpoint_access(&endpoint_id).is_ok());
}
} }


@@ -32,8 +32,11 @@ use crate::util::run_until;
type IpSubnetKey = IpNet; type IpSubnetKey = IpNet;
const CANCEL_KEY_TTL: Duration = Duration::from_secs(600); /// The initial period and TTL are shorter, to clear keys of short-lived connections faster.
const CANCEL_KEY_REFRESH: Duration = Duration::from_secs(570); const CANCEL_KEY_INITIAL_PERIOD: Duration = Duration::from_secs(60);
const CANCEL_KEY_REFRESH_PERIOD: Duration = Duration::from_secs(10 * 60);
/// `CANCEL_KEY_TTL_SLACK` is added to the periods to determine the actual TTL:
/// a fresh key lives 60 + 30 = 90 s, a refreshed key 600 + 30 = 630 s.
const CANCEL_KEY_TTL_SLACK: Duration = Duration::from_secs(30);
// Message types for sending through mpsc channel // Message types for sending through mpsc channel
pub enum CancelKeyOp { pub enum CancelKeyOp {
@@ -54,6 +57,24 @@ pub enum CancelKeyOp {
}, },
} }
impl CancelKeyOp {
const fn redis_msg_kind(&self) -> RedisMsgKind {
match self {
CancelKeyOp::Store { .. } => RedisMsgKind::Set,
CancelKeyOp::Refresh { .. } => RedisMsgKind::Expire,
CancelKeyOp::Get { .. } => RedisMsgKind::Get,
CancelKeyOp::GetOld { .. } => RedisMsgKind::HGet,
}
}
fn cancel_channel_metric_guard(&self) -> CancelChannelSizeGuard<'static> {
Metrics::get()
.proxy
.cancel_channel_size
.guard(self.redis_msg_kind())
}
}
#[derive(thiserror::Error, Debug, Clone)] #[derive(thiserror::Error, Debug, Clone)]
pub enum PipelineError { pub enum PipelineError {
#[error("could not send cmd to redis: {0}")] #[error("could not send cmd to redis: {0}")]
@@ -483,50 +504,49 @@ impl Session {
let mut cancel = pin!(cancel); let mut cancel = pin!(cancel);
enum State { enum State {
Set, Init,
Refresh, Refresh,
} }
let mut state = State::Set;
let mut state = State::Init;
loop { loop {
let guard_op = match state { let (op, mut wait_interval) = match state {
State::Set => { State::Init => {
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::Set);
let op = CancelKeyOp::Store {
key: self.key,
value: closure_json.clone(),
expire: CANCEL_KEY_TTL,
};
tracing::debug!( tracing::debug!(
src=%self.key, src=%self.key,
dest=?cancel_closure.cancel_token, dest=?cancel_closure.cancel_token,
"registering cancellation key" "registering cancellation key"
); );
(guard, op) (
CancelKeyOp::Store {
key: self.key,
value: closure_json.clone(),
expire: CANCEL_KEY_INITIAL_PERIOD + CANCEL_KEY_TTL_SLACK,
},
CANCEL_KEY_INITIAL_PERIOD,
)
} }
State::Refresh => { State::Refresh => {
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::Expire);
let op = CancelKeyOp::Refresh {
key: self.key,
expire: CANCEL_KEY_TTL,
};
tracing::debug!( tracing::debug!(
src=%self.key, src=%self.key,
dest=?cancel_closure.cancel_token, dest=?cancel_closure.cancel_token,
"refreshing cancellation key" "refreshing cancellation key"
); );
(guard, op) (
CancelKeyOp::Refresh {
key: self.key,
expire: CANCEL_KEY_REFRESH_PERIOD + CANCEL_KEY_TTL_SLACK,
},
CANCEL_KEY_REFRESH_PERIOD,
)
} }
}; };
match tx.call(guard_op, cancel.as_mut()).await { match tx
.call((op.cancel_channel_metric_guard(), op), cancel.as_mut())
.await
{
// SET returns OK // SET returns OK
Ok(Value::Okay) => { Ok(Value::Okay) => {
tracing::debug!( tracing::debug!(
@@ -549,23 +569,23 @@ impl Session {
Ok(_) => { Ok(_) => {
// Any other response likely means the key expired. // Any other response likely means the key expired.
tracing::warn!(src=%self.key, "refreshing cancellation key failed"); tracing::warn!(src=%self.key, "refreshing cancellation key failed");
// Re-enter the SET loop to repush full data. // Re-enter the SET loop quickly to repush full data.
state = State::Set; state = State::Init;
wait_interval = Duration::ZERO;
} }
// retry immediately. // retry immediately.
Err(BatchQueueError::Result(error)) => { Err(BatchQueueError::Result(error)) => {
tracing::warn!(?error, "error refreshing cancellation key"); tracing::warn!(?error, "error refreshing cancellation key");
// Small delay to prevent busy loop with high cpu and logging. // Small delay to prevent busy loop with high cpu and logging.
tokio::time::sleep(Duration::from_millis(10)).await; wait_interval = Duration::from_millis(10);
continue;
} }
Err(BatchQueueError::Cancelled(Err(_cancelled))) => break, Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
} }
// wait before continuing. break immediately if cancelled. // wait before continuing. break immediately if cancelled.
if run_until(tokio::time::sleep(CANCEL_KEY_REFRESH), cancel.as_mut()) if run_until(tokio::time::sleep(wait_interval), cancel.as_mut())
.await .await
.is_err() .is_err()
{ {


@@ -68,6 +68,66 @@ impl NeonControlPlaneClient {
self.endpoint.url().as_str() self.endpoint.url().as_str()
} }
async fn get_and_cache_auth_info<T>(
&self,
ctx: &RequestContext,
endpoint: &EndpointId,
role: &RoleName,
cache_key: &EndpointId,
extract: impl FnOnce(&EndpointAccessControl, &RoleAccessControl) -> T,
) -> Result<T, GetAuthInfoError> {
match self.do_get_auth_req(ctx, endpoint, role).await {
Ok(auth_info) => {
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
let res = extract(&control, &role_control);
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
auth_info.project_id,
cache_key.into(),
role.into(),
control,
role_control,
);
if let Some(project_id) = auth_info.project_id {
ctx.set_project_id(project_id);
}
Ok(res)
}
Err(err) => match err {
GetAuthInfoError::ApiError(ControlPlaneError::Message(ref msg)) => {
let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info);
// If we can retry this error, do not cache it,
// unless we were given a retry delay.
if msg.could_retry() && retry_info.is_none() {
return Err(err);
}
self.caches.project_info.insert_endpoint_access_err(
cache_key.into(),
role.into(),
msg.clone(),
retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)),
);
Err(err)
}
err => Err(err),
},
}
}
async fn do_get_auth_req( async fn do_get_auth_req(
&self, &self,
ctx: &RequestContext, ctx: &RequestContext,
@@ -284,43 +344,34 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
ctx: &RequestContext, ctx: &RequestContext,
endpoint: &EndpointId, endpoint: &EndpointId,
role: &RoleName, role: &RoleName,
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> { ) -> Result<RoleAccessControl, GetAuthInfoError> {
let normalized_ep = &endpoint.normalize(); let key = endpoint.normalize();
if let Some(secret) = self
if let Some((role_control, ttl)) = self
.caches .caches
.project_info .project_info
.get_role_secret(normalized_ep, role) .get_role_secret_with_ttl(&key, role)
{ {
return Ok(secret); return match role_control {
Err(mut msg) => {
info!(key = &*key, "found cached get_role_access_control error");
// if retry_delay_ms is set, change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(role_control) => {
debug!(key = &*key, "found cached role access control");
Ok(role_control)
}
};
} }
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; self.get_and_cache_auth_info(ctx, endpoint, role, &key, |_, role_control| {
role_control.clone()
let control = EndpointAccessControl { })
allowed_ips: Arc::new(auth_info.allowed_ips), .await
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
role.into(),
control,
role_control.clone(),
);
ctx.set_project_id(project_id);
}
Ok(role_control)
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
@@ -330,38 +381,30 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
endpoint: &EndpointId, endpoint: &EndpointId,
role: &RoleName, role: &RoleName,
) -> Result<EndpointAccessControl, GetAuthInfoError> { ) -> Result<EndpointAccessControl, GetAuthInfoError> {
let normalized_ep = &endpoint.normalize(); let key = endpoint.normalize();
if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
return Ok(control); if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) {
return match control {
Err(mut msg) => {
info!(
key = &*key,
"found cached get_endpoint_access_control error"
);
// if retry_delay_ms is set change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(control) => {
debug!(key = &*key, "found cached endpoint access control");
Ok(control)
}
};
} }
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; self.get_and_cache_auth_info(ctx, endpoint, role, &key, |control, _| control.clone())
.await
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
role.into(),
control.clone(),
role_control,
);
ctx.set_project_id(project_id);
}
Ok(control)
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
@@ -390,13 +433,9 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
info!(key = &*key, "found cached wake_compute error"); info!(key = &*key, "found cached wake_compute error");
// if retry_delay_ms is set, reduce it by the amount of time it spent in cache // if retry_delay_ms is set, reduce it by the amount of time it spent in cache
if let Some(status) = &mut msg.status { replace_retry_delay_ms(&mut msg, |delay| {
if let Some(retry_info) = &mut status.details.retry_info { delay.saturating_sub(created_at.elapsed().as_millis() as u64)
retry_info.retry_delay_ms = retry_info });
.retry_delay_ms
.saturating_sub(created_at.elapsed().as_millis() as u64)
}
}
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message( Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
msg, msg,
@@ -478,6 +517,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
} }
} }
fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) {
if let Some(status) = &mut msg.status
&& let Some(retry_info) = &mut status.details.retry_info
{
retry_info.retry_delay_ms = f(retry_info.retry_delay_ms);
}
}
/// Parse http response body, taking status code into account. /// Parse http response body, taking status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>( fn parse_body<T: for<'a> serde::Deserialize<'a>>(
status: StatusCode, status: StatusCode,


@@ -52,7 +52,7 @@ impl ReportableError for ControlPlaneError {
| Reason::EndpointNotFound | Reason::EndpointNotFound
| Reason::EndpointDisabled | Reason::EndpointDisabled
| Reason::BranchNotFound | Reason::BranchNotFound
| Reason::InvalidEphemeralEndpointOptions => ErrorKind::User, | Reason::WrongLsnOrTimestamp => ErrorKind::User,
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit, Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,


@@ -107,7 +107,7 @@ pub(crate) struct ErrorInfo {
// Schema could also have `metadata` field, but it's not structured. Skip it for now. // Schema could also have `metadata` field, but it's not structured. Skip it for now.
} }
#[derive(Clone, Copy, Debug, Deserialize, Default)] #[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
pub(crate) enum Reason { pub(crate) enum Reason {
/// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles. /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
#[serde(rename = "ROLE_PROTECTED")] #[serde(rename = "ROLE_PROTECTED")]
@@ -133,9 +133,9 @@ pub(crate) enum Reason {
/// or that the subject doesn't have enough permissions to access the requested branch. /// or that the subject doesn't have enough permissions to access the requested branch.
#[serde(rename = "BRANCH_NOT_FOUND")] #[serde(rename = "BRANCH_NOT_FOUND")]
BranchNotFound, BranchNotFound,
/// InvalidEphemeralEndpointOptions indicates that the specified LSN or timestamp are wrong. /// WrongLsnOrTimestamp indicates that the specified LSN or timestamp are wrong.
#[serde(rename = "INVALID_EPHEMERAL_OPTIONS")] #[serde(rename = "WRONG_LSN_OR_TIMESTAMP")]
InvalidEphemeralEndpointOptions, WrongLsnOrTimestamp,
/// RateLimitExceeded indicates that the rate limit for the operation has been exceeded. /// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
#[serde(rename = "RATE_LIMIT_EXCEEDED")] #[serde(rename = "RATE_LIMIT_EXCEEDED")]
RateLimitExceeded, RateLimitExceeded,
@@ -205,7 +205,7 @@ impl Reason {
| Reason::EndpointNotFound | Reason::EndpointNotFound
| Reason::EndpointDisabled | Reason::EndpointDisabled
| Reason::BranchNotFound | Reason::BranchNotFound
| Reason::InvalidEphemeralEndpointOptions => false, | Reason::WrongLsnOrTimestamp => false,
// we were asked to go away // we were asked to go away
Reason::RateLimitExceeded Reason::RateLimitExceeded
| Reason::NonDefaultBranchComputeTimeExceeded | Reason::NonDefaultBranchComputeTimeExceeded
@@ -257,19 +257,19 @@ pub(crate) struct GetEndpointAccessControl {
pub(crate) rate_limits: EndpointRateLimitConfig, pub(crate) rate_limits: EndpointRateLimitConfig,
} }
#[derive(Copy, Clone, Deserialize, Default)] #[derive(Copy, Clone, Deserialize, Default, Debug)]
pub struct EndpointRateLimitConfig { pub struct EndpointRateLimitConfig {
pub connection_attempts: ConnectionAttemptsLimit, pub connection_attempts: ConnectionAttemptsLimit,
} }
#[derive(Copy, Clone, Deserialize, Default)] #[derive(Copy, Clone, Deserialize, Default, Debug)]
pub struct ConnectionAttemptsLimit { pub struct ConnectionAttemptsLimit {
pub tcp: Option<LeakyBucketSetting>, pub tcp: Option<LeakyBucketSetting>,
pub ws: Option<LeakyBucketSetting>, pub ws: Option<LeakyBucketSetting>,
pub http: Option<LeakyBucketSetting>, pub http: Option<LeakyBucketSetting>,
} }
#[derive(Copy, Clone, Deserialize)] #[derive(Copy, Clone, Deserialize, Debug)]
pub struct LeakyBucketSetting { pub struct LeakyBucketSetting {
pub rps: f64, pub rps: f64,
pub burst: f64, pub burst: f64,


@@ -82,7 +82,7 @@ impl NodeInfo {
} }
} }
#[derive(Copy, Clone, Default)] #[derive(Copy, Clone, Default, Debug)]
pub(crate) struct AccessBlockerFlags { pub(crate) struct AccessBlockerFlags {
pub public_access_blocked: bool, pub public_access_blocked: bool,
pub vpc_access_blocked: bool, pub vpc_access_blocked: bool,
@@ -92,12 +92,12 @@ pub(crate) type NodeInfoCache =
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>; TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct RoleAccessControl { pub struct RoleAccessControl {
pub secret: Option<AuthSecret>, pub secret: Option<AuthSecret>,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct EndpointAccessControl { pub struct EndpointAccessControl {
pub allowed_ips: Arc<Vec<IpPattern>>, pub allowed_ips: Arc<Vec<IpPattern>>,
pub allowed_vpce: Arc<Vec<String>>, pub allowed_vpce: Arc<Vec<String>>,


@@ -21,7 +21,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
| Scope::GenerationsApi | Scope::GenerationsApi
| Scope::Infra | Scope::Infra
| Scope::Scrubber | Scope::Scrubber
| Scope::ControllerPeer, | Scope::ControllerPeer
| Scope::TenantEndpoint,
_, _,
) => Err(AuthError( ) => Err(AuthError(
format!( format!(


@@ -52,6 +52,7 @@ tokio-rustls.workspace = true
tokio-util.workspace = true tokio-util.workspace = true
tokio.workspace = true tokio.workspace = true
tracing.workspace = true tracing.workspace = true
uuid.workspace = true
measured.workspace = true measured.workspace = true
rustls.workspace = true rustls.workspace = true
scopeguard.workspace = true scopeguard.workspace = true
@@ -63,6 +64,7 @@ tokio-postgres-rustls.workspace = true
diesel = { version = "2.2.6", features = [ diesel = { version = "2.2.6", features = [
"serde_json", "serde_json",
"chrono", "chrono",
"uuid",
] } ] }
diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connection-wrapper"] } diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connection-wrapper"] }
diesel_migrations = { version = "2.2.0" } diesel_migrations = { version = "2.2.0" }

View File

@@ -0,0 +1,2 @@
DROP TABLE hadron_safekeepers;
DROP TABLE hadron_timeline_safekeepers;


@@ -0,0 +1,17 @@
-- hadron_safekeepers keeps track of all safekeeper nodes that exist in the system.
-- Upon startup, each safekeeper reaches out to the hadron cluster coordinator to register its node ID and listen addresses.
CREATE TABLE hadron_safekeepers (
sk_node_id BIGINT PRIMARY KEY NOT NULL,
listen_http_addr VARCHAR NOT NULL,
listen_http_port INTEGER NOT NULL,
listen_pg_addr VARCHAR NOT NULL,
listen_pg_port INTEGER NOT NULL
);
CREATE TABLE hadron_timeline_safekeepers (
timeline_id VARCHAR NOT NULL,
sk_node_id BIGINT NOT NULL,
legacy_endpoint_id UUID DEFAULT NULL,
PRIMARY KEY(timeline_id, sk_node_id)
);


@@ -1,4 +1,5 @@
use utils::auth::{AuthError, Claims, Scope}; use utils::auth::{AuthError, Claims, Scope};
use uuid::Uuid;
pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), AuthError> { pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), AuthError> {
if claims.scope != required_scope { if claims.scope != required_scope {
@@ -7,3 +8,14 @@ pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), Au
Ok(()) Ok(())
} }
#[allow(dead_code)]
pub fn check_endpoint_permission(claims: &Claims, endpoint_id: Uuid) -> Result<(), AuthError> {
if claims.scope != Scope::TenantEndpoint {
return Err(AuthError("Scope mismatch. Permission denied".into()));
}
if claims.endpoint_id != Some(endpoint_id) {
return Err(AuthError("Endpoint id mismatch. Permission denied".into()));
}
Ok(())
}


@@ -5,6 +5,8 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use anyhow::Context; use anyhow::Context;
use compute_api::spec::PageserverProtocol;
use compute_api::spec::PageserverShardInfo;
use control_plane::endpoint::{ use control_plane::endpoint::{
ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo, ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo,
}; };
@@ -13,7 +15,7 @@ use futures::StreamExt;
use hyper::StatusCode; use hyper::StatusCode;
use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT; use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT;
use pageserver_api::controller_api::AvailabilityZone; use pageserver_api::controller_api::AvailabilityZone;
use pageserver_api::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId}; use pageserver_api::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize, TenantShardId};
use postgres_connection::parse_host_port; use postgres_connection::parse_host_port;
use safekeeper_api::membership::SafekeeperGeneration; use safekeeper_api::membership::SafekeeperGeneration;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -507,7 +509,16 @@ impl ApiMethod for ComputeHookTenant {
if endpoint.tenant_id == *tenant_id && endpoint.status() == EndpointStatus::Running { if endpoint.tenant_id == *tenant_id && endpoint.status() == EndpointStatus::Running {
tracing::info!("Reconfiguring pageservers for endpoint {endpoint_name}"); tracing::info!("Reconfiguring pageservers for endpoint {endpoint_name}");
let mut shard_conninfos = HashMap::new(); let shard_count = ShardCount(shards.len().try_into().expect("too many shards"));
let mut shard_infos: HashMap<ShardIndex, PageserverShardInfo> = HashMap::new();
let prefer_protocol = if endpoint.grpc {
PageserverProtocol::Grpc
} else {
PageserverProtocol::Libpq
};
for shard in shards.iter() { for shard in shards.iter() {
let ps_conf = env let ps_conf = env
.get_pageserver_conf(shard.node_id) .get_pageserver_conf(shard.node_id)
@@ -528,19 +539,31 @@ impl ApiMethod for ComputeHookTenant {
None None
}; };
let pageserver = PageserverShardConnectionInfo { let pageserver = PageserverShardConnectionInfo {
id: Some(shard.node_id.to_string()),
libpq_url, libpq_url,
grpc_url, grpc_url,
}; };
shard_conninfos.insert(shard.shard_number.0 as u32, pageserver); let shard_info = PageserverShardInfo {
pageservers: vec![pageserver],
};
shard_infos.insert(
ShardIndex {
shard_number: shard.shard_number,
shard_count,
},
shard_info,
);
} }
let pageserver_conninfo = PageserverConnectionInfo { let pageserver_conninfo = PageserverConnectionInfo {
shards: shard_conninfos, shard_count: ShardCount::unsharded(),
prefer_grpc: endpoint.grpc, stripe_size: stripe_size.map(|val| val.0),
shards: shard_infos,
prefer_protocol,
}; };
endpoint endpoint
.reconfigure_pageservers(pageserver_conninfo, *stripe_size) .reconfigure_pageservers(&pageserver_conninfo)
.await .await
.map_err(NotifyError::NeonLocal)?; .map_err(NotifyError::NeonLocal)?;
} }
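For reference, a minimal sketch of the structure assembled above, shown for an unsharded tenant. The URL and node ID are placeholders, and the field types are inferred from this diff rather than from the spec definitions:

    let mut shards: HashMap<ShardIndex, PageserverShardInfo> = HashMap::new();
    shards.insert(
        ShardIndex::unsharded(),
        PageserverShardInfo {
            pageservers: vec![PageserverShardConnectionInfo {
                id: Some("1".to_string()),
                libpq_url: Some("postgresql://no_user@127.0.0.1:64000".to_string()),
                grpc_url: None,
            }],
        },
    );
    let pageserver_conninfo = PageserverConnectionInfo {
        shard_count: ShardCount::unsharded(),
        stripe_size: None,
        shards,
        prefer_protocol: PageserverProtocol::Libpq,
    };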
@@ -824,6 +847,7 @@ impl ComputeHook {
let send_locked = tokio::select! { let send_locked = tokio::select! {
guard = send_lock.lock_owned() => {guard}, guard = send_lock.lock_owned() => {guard},
_ = cancel.cancelled() => { _ = cancel.cancelled() => {
tracing::info!("Notification cancelled while waiting for lock");
return Err(NotifyError::ShuttingDown) return Err(NotifyError::ShuttingDown)
} }
}; };
@@ -865,11 +889,32 @@ impl ComputeHook {
let notify_url = compute_hook_url.as_ref().unwrap(); let notify_url = compute_hook_url.as_ref().unwrap();
self.do_notify(notify_url, &request, cancel).await self.do_notify(notify_url, &request, cancel).await
} else { } else {
self.do_notify_local::<M>(&request).await.map_err(|e| { match self.do_notify_local::<M>(&request).await.map_err(|e| {
// This path is for testing only, so munge the error into our prod-style error type. // This path is for testing only, so munge the error into our prod-style error type.
tracing::error!("neon_local notification hook failed: {e}"); if e.to_string().contains("refresh-configuration-pending") {
NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR) // If the error message mentions "refresh-configuration-pending", it means the compute node
}) // rejected our notification request because it is already trying to reconfigure itself. We
// can proceed with the rest of the reconciliation process, as the compute node has already
// discovered the need to reconfigure and will eventually update its configuration once
// we update the pageserver mappings. In fact, it is important that we continue with
// reconciliation to make sure we update the pageserver mappings to unblock the compute node.
tracing::info!("neon_local notification hook failed: {e}");
tracing::info!("Notification failed likely due to compute node self-reconfiguration, will retry.");
Ok(())
} else {
tracing::error!("neon_local notification hook failed: {e}");
Err(NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR))
}
}) {
// Compute node accepted the notification request. Ok to proceed.
Ok(_) => Ok(()),
// Compute node rejected our request but it is already self-reconfiguring. Ok to proceed.
Err(Ok(_)) => Ok(()),
// Fail the reconciliation attempt in all other cases. Recall that this whole code path involving
// neon_local is for testing only. In production we always retry failed reconciliations so we
// don't have any dead ends here.
Err(Err(e)) => Err(e),
}
}; };
match result { match result {


@@ -0,0 +1,44 @@
use std::collections::BTreeMap;
use rand::Rng;
use utils::shard::TenantShardId;
static CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*()";
/// Generate a random string of `length` characters that can be used as a password. The generated string
/// contains alphanumeric characters and special characters (!@#$%^&*())
pub fn generate_random_password(length: usize) -> String {
let mut rng = rand::thread_rng();
(0..length)
.map(|_| {
let idx = rng.gen_range(0..CHARSET.len());
CHARSET[idx] as char
})
.collect()
}
pub(crate) struct TenantShardSizeMap {
#[expect(dead_code)]
pub map: BTreeMap<TenantShardId, u64>,
}
impl TenantShardSizeMap {
pub fn new(map: BTreeMap<TenantShardId, u64>) -> Self {
Self { map }
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_generate_random_password() {
let pwd1 = generate_random_password(10);
assert_eq!(pwd1.len(), 10);
let pwd2 = generate_random_password(10);
assert_ne!(pwd1, pwd2);
assert!(pwd1.chars().all(|c| CHARSET.contains(&(c as u8))));
assert!(pwd2.chars().all(|c| CHARSET.contains(&(c as u8))));
}
}


@@ -48,7 +48,10 @@ use crate::metrics::{
}; };
use crate::persistence::SafekeeperUpsert; use crate::persistence::SafekeeperUpsert;
use crate::reconciler::ReconcileError; use crate::reconciler::ReconcileError;
use crate::service::{LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service}; use crate::service::{
LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service,
TenantMutationLocations,
};
/// State available to HTTP request handlers /// State available to HTTP request handlers
pub struct HttpState { pub struct HttpState {
@@ -734,77 +737,104 @@ async fn handle_tenant_timeline_passthrough(
path path
); );
// Find the node that holds shard zero let tenant_shard_id = if tenant_or_shard_id.is_unsharded() {
let (node, tenant_shard_id, consistent) = if tenant_or_shard_id.is_unsharded() { // If the request contains only tenant ID, find the node that holds shard zero
service let (_, shard_id) = service
.tenant_shard0_node(tenant_or_shard_id.tenant_id) .tenant_shard0_node(tenant_or_shard_id.tenant_id)
.await? .await?;
shard_id
} else { } else {
let (node, consistent) = service.tenant_shard_node(tenant_or_shard_id).await?; tenant_or_shard_id
(node, tenant_or_shard_id, consistent)
}; };
// Callers will always pass an unsharded tenant ID. Before proxying, we must let service_inner = service.clone();
// rewrite this to a shard-aware shard zero ID.
let path = format!("{path}");
let tenant_str = tenant_or_shard_id.tenant_id.to_string();
let tenant_shard_str = format!("{tenant_shard_id}");
let path = path.replace(&tenant_str, &tenant_shard_str);
let latency = &METRICS_REGISTRY service.tenant_shard_remote_mutation(tenant_shard_id, |locations| async move {
.metrics_group let TenantMutationLocations(locations) = locations;
.storage_controller_passthrough_request_latency; if locations.is_empty() {
return Err(ApiError::NotFound(anyhow::anyhow!("Tenant {} not found", tenant_or_shard_id.tenant_id).into()));
}
let path_label = path_without_ids(&path) let (tenant_or_shard_id, locations) = locations.into_iter().next().unwrap();
.split('/') let node = locations.latest.node;
.filter(|token| !token.is_empty())
.collect::<Vec<_>>()
.join("_");
let labels = PageserverRequestLabelGroup {
pageserver_id: &node.get_id().to_string(),
path: &path_label,
method: crate::metrics::Method::Get,
};
let _timer = latency.start_timer(labels.clone()); // Callers will always pass an unsharded tenant ID. Before proxying, we must
// rewrite this to a shard-aware shard zero ID.
let path = format!("{path}");
let tenant_str = tenant_or_shard_id.tenant_id.to_string();
let tenant_shard_str = format!("{tenant_shard_id}");
let path = path.replace(&tenant_str, &tenant_shard_str);
let client = mgmt_api::Client::new( let latency = &METRICS_REGISTRY
service.get_http_client().clone(),
node.base_url(),
service.get_config().pageserver_jwt_token.as_deref(),
);
let resp = client.op_raw(method, path).await.map_err(|e|
// We return 503 here because if we can't successfully send a request to the pageserver,
// either we aren't available or the pageserver is unavailable.
ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?;
if !resp.status().is_success() {
let error_counter = &METRICS_REGISTRY
.metrics_group .metrics_group
.storage_controller_passthrough_request_error; .storage_controller_passthrough_request_latency;
error_counter.inc(labels);
}
// Transform 404 into 503 if we raced with a migration let path_label = path_without_ids(&path)
if resp.status() == reqwest::StatusCode::NOT_FOUND && !consistent { .split('/')
// Rather than retry here, send the client a 503 to prompt a retry: this matches .filter(|token| !token.is_empty())
// the pageserver's use of 503, and all clients calling this API should retry on 503. .collect::<Vec<_>>()
return Err(ApiError::ResourceUnavailable( .join("_");
format!("Pageserver {node} returned 404 due to ongoing migration, retry later").into(), let labels = PageserverRequestLabelGroup {
)); pageserver_id: &node.get_id().to_string(),
} path: &path_label,
method: crate::metrics::Method::Get,
};
// We have a reqest::Response, would like a http::Response let _timer = latency.start_timer(labels.clone());
let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp.status())?);
for (k, v) in resp.headers() {
builder = builder.header(k.as_str(), v.as_bytes());
}
let response = builder let client = mgmt_api::Client::new(
.body(Body::wrap_stream(resp.bytes_stream())) service_inner.get_http_client().clone(),
.map_err(|e| ApiError::InternalServerError(e.into()))?; node.base_url(),
service_inner.get_config().pageserver_jwt_token.as_deref(),
);
let resp = client.op_raw(method, path).await.map_err(|e|
// We return 503 here because if we can't successfully send a request to the pageserver,
// either we aren't available or the pageserver is unavailable.
ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?;
Ok(response) if !resp.status().is_success() {
let error_counter = &METRICS_REGISTRY
.metrics_group
.storage_controller_passthrough_request_error;
error_counter.inc(labels);
}
let resp_status = resp.status();
// We have a reqwest::Response, would like an http::Response
let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp_status)?);
for (k, v) in resp.headers() {
builder = builder.header(k.as_str(), v.as_bytes());
}
let resp_bytes = resp
.bytes()
.await
.map_err(|e| ApiError::InternalServerError(e.into()))?;
// Inspect 404 errors: at this point, we know that the tenant exists, but the pageserver we route
// the request to might not yet be ready. Therefore, if it is a _tenant_ not found error, we can
// convert it into a 503. TODO: we should make this part of the check in `tenant_shard_remote_mutation`.
// However, `tenant_shard_remote_mutation` currently cannot inspect the HTTP error response body,
// so we have to do it here instead.
if resp_status == reqwest::StatusCode::NOT_FOUND {
let resp_str = std::str::from_utf8(&resp_bytes)
.map_err(|e| ApiError::InternalServerError(e.into()))?;
// We only handle "tenant not found" errors; other 404s like timeline not found should
// be forwarded as-is.
if Service::is_tenant_not_found_error(resp_str, tenant_or_shard_id.tenant_id) {
// Rather than retry here, send the client a 503 to prompt a retry: this matches
// the pageserver's use of 503, and all clients calling this API should retry on 503.
return Err(ApiError::ResourceUnavailable(
format!(
"Pageserver {node} returned tenant 404 due to ongoing migration, retry later"
)
.into(),
));
}
}
let response = builder
.body(Body::from(resp_bytes))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}).await?
} }
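To make the rewrite step concrete, a sketch of the URL transformation performed before proxying. The tenant ID is illustrative; the shard suffix follows the usual TenantShardId slug form (hex shard number, then hex shard count), so shard zero of a four-shard tenant ends in "-0004":

    let path = "/v1/tenant/1f359dd625e519a1a4e8d7509690f6fc/timeline";
    let rewritten = path.replace(
        "1f359dd625e519a1a4e8d7509690f6fc",
        "1f359dd625e519a1a4e8d7509690f6fc-0004", // shard 0 of count 4
    );
    assert_eq!(rewritten, "/v1/tenant/1f359dd625e519a1a4e8d7509690f6fc-0004/timeline");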
async fn handle_tenant_locate( async fn handle_tenant_locate(
@@ -1085,9 +1115,10 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
let state = get_state(&req); let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?; let node_id: NodeId = parse_request_param(&req, "node_id")?;
let force: bool = parse_query_param(&req, "force")?.unwrap_or(false);
json_response( json_response(
StatusCode::OK, StatusCode::OK,
state.service.start_node_delete(node_id).await?, state.service.start_node_delete(node_id, force).await?,
) )
} }


@@ -6,6 +6,7 @@ extern crate hyper0 as hyper;
mod auth; mod auth;
mod background_node_operations; mod background_node_operations;
mod compute_hook; mod compute_hook;
pub mod hadron_utils;
mod heartbeater; mod heartbeater;
pub mod http; pub mod http;
mod id_lock_map; mod id_lock_map;


@@ -76,8 +76,8 @@ pub(crate) struct StorageControllerMetricGroup {
/// How many shards would like to reconcile but were blocked by concurrency limits /// How many shards would like to reconcile but were blocked by concurrency limits
pub(crate) storage_controller_pending_reconciles: measured::Gauge, pub(crate) storage_controller_pending_reconciles: measured::Gauge,
/// How many shards are keep-failing and will be ignored when considering to run optimizations /// How many shards are stuck and will be ignored when deciding whether to run optimizations
pub(crate) storage_controller_keep_failing_reconciles: measured::Gauge, pub(crate) storage_controller_stuck_reconciles: measured::Gauge,
/// HTTP request status counters for handled requests /// HTTP request status counters for handled requests
pub(crate) storage_controller_http_request_status: pub(crate) storage_controller_http_request_status:
@@ -151,6 +151,29 @@ pub(crate) struct StorageControllerMetricGroup {
/// Indicator of completed safekeeper reconciles, broken down by safekeeper. /// Indicator of completed safekeeper reconciles, broken down by safekeeper.
pub(crate) storage_controller_safekeeper_reconciles_complete: pub(crate) storage_controller_safekeeper_reconciles_complete:
measured::CounterVec<SafekeeperReconcilerLabelGroupSet>, measured::CounterVec<SafekeeperReconcilerLabelGroupSet>,
/* BEGIN HADRON */
/// Hadron `config_watcher` reconciliation runs completed, broken down by success/failure.
pub(crate) storage_controller_config_watcher_complete:
measured::CounterVec<ConfigWatcherCompleteLabelGroupSet>,
/// Hadron long waits for node state changes during drain and fill.
pub(crate) storage_controller_drain_and_fill_long_waits: measured::Counter,
/// Set to 1 if we detect any page server pods with pending node pool rotation annotations.
/// Requires manual reset after oncall investigation.
pub(crate) storage_controller_ps_node_pool_rotation_pending: measured::Gauge,
/// Hadron storage scrubber status.
pub(crate) storage_controller_storage_scrub_status:
measured::CounterVec<StorageScrubberLabelGroupSet>,
/// Desired number of pageservers managed by the storage controller
pub(crate) storage_controller_num_pageservers_desired: measured::Gauge,
/// Desired number of safekeepers managed by the storage controller
pub(crate) storage_controller_num_safekeeper_desired: measured::Gauge,
/* END HADRON */
} }
impl StorageControllerMetrics { impl StorageControllerMetrics {
@@ -173,6 +196,10 @@ impl Default for StorageControllerMetrics {
.storage_controller_reconcile_complete .storage_controller_reconcile_complete
.init_all_dense(); .init_all_dense();
metrics_group
.storage_controller_config_watcher_complete
.init_all_dense();
Self { Self {
metrics_group, metrics_group,
encoder: Mutex::new(measured::text::BufferedTextEncoder::new()), encoder: Mutex::new(measured::text::BufferedTextEncoder::new()),
@@ -262,11 +289,48 @@ pub(crate) struct ReconcileLongRunningLabelGroup<'a> {
pub(crate) sequence: &'a str, pub(crate) sequence: &'a str,
} }
#[derive(measured::LabelGroup, Clone)]
#[label(set = StorageScrubberLabelGroupSet)]
pub(crate) struct StorageScrubberLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) tenant_id: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) shard_number: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) timeline_id: &'a str,
pub(crate) outcome: StorageScrubberOutcome,
}
#[derive(FixedCardinalityLabel, Clone, Copy)]
pub(crate) enum StorageScrubberOutcome {
PSOk,
PSWarning,
PSError,
PSOrphan,
SKOk,
SKError,
}
#[derive(measured::LabelGroup)]
#[label(set = ConfigWatcherCompleteLabelGroupSet)]
pub(crate) struct ConfigWatcherCompleteLabelGroup {
// Reuse the ReconcileOutcome from the SC's reconciliation metrics.
pub(crate) status: ReconcileOutcome,
}
#[derive(FixedCardinalityLabel, Clone, Copy)] #[derive(FixedCardinalityLabel, Clone, Copy)]
pub(crate) enum ReconcileOutcome { pub(crate) enum ReconcileOutcome {
// Successfully reconciled everything.
#[label(rename = "ok")] #[label(rename = "ok")]
Success, Success,
// Used by tenant-shard reconciler only. Reconciled pageserver state successfully,
// but failed to deliver the compute notification. This error is typically transient,
// but if its occurrence keeps increasing, it should be investigated.
#[label(rename = "ok_no_notify")]
SuccessNoNotify,
// We failed to reconcile some state and the reconciliation will be retried.
Error, Error,
// Reconciliation was cancelled.
Cancel, Cancel,
} }


@@ -51,6 +51,39 @@ pub(crate) struct Node {
cancel: CancellationToken, cancel: CancellationToken,
} }
#[allow(dead_code)]
const ONE_MILLION: i64 = 1000000;
/// Converts a pool ID to a large number that can be used to assign unique IDs to pods in StatefulSets.
/// For example, if pool_id is 1, then the pods have NodeIds 1000000, 1000001, 1000002, etc.
/// If pool_id is None, then the pods have NodeIds 0, 1, 2, etc.
#[allow(dead_code)]
pub fn transform_pool_id(pool_id: Option<i32>) -> i64 {
match pool_id {
Some(id) => (id as i64) * ONE_MILLION,
None => 0,
}
}
#[allow(dead_code)]
pub fn get_pool_id_from_node_id(node_id: i64) -> i32 {
(node_id / ONE_MILLION) as i32
}
/// Example pod name: page-server-0-1, safe-keeper-1-0
#[allow(dead_code)]
pub fn get_node_id_from_pod_name(pod_name: &str) -> anyhow::Result<NodeId> {
let parts: Vec<&str> = pod_name.split('-').collect();
if parts.len() != 4 {
return Err(anyhow::anyhow!("Invalid pod name: {}", pod_name));
}
let pool_id = parts[2].parse::<i32>()?;
let node_offset = parts[3].parse::<i64>()?;
let node_id = transform_pool_id(Some(pool_id)) + node_offset;
Ok(NodeId(node_id as u64))
}
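A quick round-trip check of the mapping described above (pool 3, pod ordinal 12), matching the unit test further down:

    assert_eq!(transform_pool_id(Some(3)) + 12, 3_000_012);
    assert_eq!(get_pool_id_from_node_id(3_000_012), 3);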
/// When updating [`Node::availability`] we use this type to indicate to the caller /// When updating [`Node::availability`] we use this type to indicate to the caller
/// whether/how they changed it. /// whether/how they changed it.
pub(crate) enum AvailabilityTransition { pub(crate) enum AvailabilityTransition {
@@ -403,3 +436,25 @@ impl std::fmt::Debug for Node {
write!(f, "{} ({})", self.id, self.listen_http_addr) write!(f, "{} ({})", self.id, self.listen_http_addr)
} }
} }
#[cfg(test)]
mod tests {
use utils::id::NodeId;
use crate::node::get_node_id_from_pod_name;
#[test]
fn test_get_node_id_from_pod_name() {
let pod_name = "page-server-3-12";
let node_id = get_node_id_from_pod_name(pod_name).unwrap();
assert_eq!(node_id, NodeId(3000012));
let pod_name = "safe-keeper-1-0";
let node_id = get_node_id_from_pod_name(pod_name).unwrap();
assert_eq!(node_id, NodeId(1000000));
let pod_name = "invalid-pod-name";
let result = get_node_id_from_pod_name(pod_name);
assert!(result.is_err());
}
}


@@ -14,6 +14,8 @@ use reqwest::StatusCode;
use utils::id::{NodeId, TenantId, TimelineId}; use utils::id::{NodeId, TenantId, TimelineId};
use utils::lsn::Lsn; use utils::lsn::Lsn;
use crate::hadron_utils::TenantShardSizeMap;
/// Thin wrapper around [`pageserver_client::mgmt_api::Client`]. It allows the storage /// Thin wrapper around [`pageserver_client::mgmt_api::Client`]. It allows the storage
/// controller to collect metrics in a non-intrusive manner. /// controller to collect metrics in a non-intrusive manner.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -86,6 +88,31 @@ impl PageserverClient {
) )
} }
#[expect(dead_code)]
pub(crate) async fn tenant_timeline_compact(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
force_image_layer_creation: bool,
wait_until_done: bool,
) -> Result<()> {
measured_request!(
"tenant_timeline_compact",
crate::metrics::Method::Put,
&self.node_id_label,
self.inner
.tenant_timeline_compact(
tenant_shard_id,
timeline_id,
force_image_layer_creation,
true,
false,
wait_until_done,
)
.await
)
}
/* BEGIN_HADRON */ /* BEGIN_HADRON */
pub(crate) async fn tenant_timeline_describe( pub(crate) async fn tenant_timeline_describe(
&self, &self,
@@ -101,6 +128,17 @@ impl PageserverClient {
.await .await
) )
} }
#[expect(dead_code)]
pub(crate) async fn list_tenant_visible_size(&self) -> Result<TenantShardSizeMap> {
measured_request!(
"list_tenant_visible_size",
crate::metrics::Method::Get,
&self.node_id_label,
self.inner.list_tenant_visible_size().await
)
.map(TenantShardSizeMap::new)
}
/* END_HADRON */ /* END_HADRON */
pub(crate) async fn tenant_scan_remote_storage( pub(crate) async fn tenant_scan_remote_storage(
@@ -365,6 +403,16 @@ impl PageserverClient {
) )
} }
#[expect(dead_code)]
pub(crate) async fn reset_alert_gauges(&self) -> Result<()> {
measured_request!(
"reset_alert_gauges",
crate::metrics::Method::Post,
&self.node_id_label,
self.inner.reset_alert_gauges().await
)
}
pub(crate) async fn wait_lsn( pub(crate) async fn wait_lsn(
&self, &self,
tenant_shard_id: TenantShardId, tenant_shard_id: TenantShardId,


@@ -862,11 +862,11 @@ impl Reconciler {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => { Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
if refreshed { if refreshed {
tracing::info!( tracing::info!(
node_id=%node.get_id(), "Observed configuration correct after refresh. Notifying compute."); node_id=%node.get_id(), "[Attached] Observed configuration correct after refresh. Notifying compute.");
self.compute_notify().await?; self.compute_notify().await?;
} else { } else {
// Nothing to do // Nothing to do
tracing::info!(node_id=%node.get_id(), "Observed configuration already correct."); tracing::info!(node_id=%node.get_id(), "[Attached] Observed configuration already correct.");
} }
} }
observed => { observed => {
@@ -945,17 +945,17 @@ impl Reconciler {
match self.observed.locations.get(&node.get_id()) { match self.observed.locations.get(&node.get_id()) {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => { Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
// Nothing to do // Nothing to do
tracing::info!(node_id=%node.get_id(), "Observed configuration already correct.") tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration already correct.")
} }
_ => { _ => {
// Only try and configure secondary locations on nodes that are available. This // Only try and configure secondary locations on nodes that are available. This
// allows the reconciler to "succeed" while some secondaries are offline (e.g. after // allows the reconciler to "succeed" while some secondaries are offline (e.g. after
// a node failure, where the failed node will have a secondary intent) // a node failure, where the failed node will have a secondary intent)
if node.is_available() { if node.is_available() {
tracing::info!(node_id=%node.get_id(), "Observed configuration requires update."); tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration requires update.");
changes.push((node.clone(), wanted_conf)) changes.push((node.clone(), wanted_conf))
} else { } else {
tracing::info!(node_id=%node.get_id(), "Skipping configuration as secondary, node is unavailable"); tracing::info!(node_id=%node.get_id(), "[Secondary] Skipping configuration as secondary, node is unavailable");
self.observed self.observed
.locations .locations
.insert(node.get_id(), ObservedStateLocation { conf: None }); .insert(node.get_id(), ObservedStateLocation { conf: None });
@@ -1066,6 +1066,9 @@ impl Reconciler {
} }
result result
} else { } else {
tracing::info!(
"Compute notification is skipped because the tenant shard does not have an attached (primary) location"
);
Ok(()) Ok(())
} }
} }


@@ -13,6 +13,24 @@ diesel::table! {
} }
} }
diesel::table! {
hadron_safekeepers (sk_node_id) {
sk_node_id -> Int8,
listen_http_addr -> Varchar,
listen_http_port -> Int4,
listen_pg_addr -> Varchar,
listen_pg_port -> Int4,
}
}
diesel::table! {
hadron_timeline_safekeepers (timeline_id, sk_node_id) {
timeline_id -> Varchar,
sk_node_id -> Int8,
legacy_endpoint_id -> Nullable<Uuid>,
}
}
diesel::table! { diesel::table! {
metadata_health (tenant_id, shard_number, shard_count) { metadata_health (tenant_id, shard_number, shard_count) {
tenant_id -> Varchar, tenant_id -> Varchar,
@@ -105,6 +123,8 @@ diesel::table! {
diesel::allow_tables_to_appear_in_same_query!( diesel::allow_tables_to_appear_in_same_query!(
controllers, controllers,
hadron_safekeepers,
hadron_timeline_safekeepers,
metadata_health, metadata_health,
nodes, nodes,
safekeeper_timeline_pending_ops, safekeeper_timeline_pending_ops,


@@ -207,34 +207,13 @@ enum ShardGenerationValidity {
}, },
} }
/// We collect the state of attachments for some operations to determine if the operation
/// needs to be retried when it fails.
struct TenantShardAttachState {
/// The targets of the operation.
///
/// Tenant shard ID, node ID, node, is intent node observed primary.
targets: Vec<(TenantShardId, NodeId, Node, bool)>,
/// The targets grouped by node ID.
by_node_id: HashMap<NodeId, (TenantShardId, Node, bool)>,
}
impl TenantShardAttachState {
fn for_api_call(&self) -> Vec<(TenantShardId, Node)> {
self.targets
.iter()
.map(|(tenant_shard_id, _, node, _)| (*tenant_shard_id, node.clone()))
.collect()
}
}
pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128; pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128;
pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256; pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256;
pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32; pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32;
// Number of consecutive reconciliation errors, occured for one shard, // Number of consecutive reconciliations that have occurred for one shard,
// after which the shard is ignored when considering to run optimizations. // after which the shard is ignored when deciding whether to run optimizations.
const MAX_CONSECUTIVE_RECONCILIATION_ERRORS: usize = 5; const MAX_CONSECUTIVE_RECONCILES: usize = 10;
// Depth of the channel used to enqueue shards for reconciliation when they can't do it immediately. // Depth of the channel used to enqueue shards for reconciliation when they can't do it immediately.
// This channel is finite-size to avoid using excessive memory if we get into a state where reconciles are finishing more slowly // This channel is finite-size to avoid using excessive memory if we get into a state where reconciles are finishing more slowly
@@ -719,47 +698,70 @@ pub(crate) enum ReconcileResultRequest {
} }
#[derive(Clone)] #[derive(Clone)]
struct MutationLocation { pub(crate) struct MutationLocation {
node: Node, pub(crate) node: Node,
generation: Generation, pub(crate) generation: Generation,
} }
#[derive(Clone)] #[derive(Clone)]
struct ShardMutationLocations { pub(crate) struct ShardMutationLocations {
latest: MutationLocation, pub(crate) latest: MutationLocation,
other: Vec<MutationLocation>, pub(crate) other: Vec<MutationLocation>,
} }
#[derive(Default, Clone)] #[derive(Default, Clone)]
struct TenantMutationLocations(BTreeMap<TenantShardId, ShardMutationLocations>); pub(crate) struct TenantMutationLocations(pub BTreeMap<TenantShardId, ShardMutationLocations>);
struct ReconcileAllResult { struct ReconcileAllResult {
spawned_reconciles: usize, spawned_reconciles: usize,
keep_failing_reconciles: usize, stuck_reconciles: usize,
has_delayed_reconciles: bool, has_delayed_reconciles: bool,
} }
impl ReconcileAllResult { impl ReconcileAllResult {
fn new( fn new(
spawned_reconciles: usize, spawned_reconciles: usize,
keep_failing_reconciles: usize, stuck_reconciles: usize,
has_delayed_reconciles: bool, has_delayed_reconciles: bool,
) -> Self { ) -> Self {
assert!( assert!(
spawned_reconciles >= keep_failing_reconciles, spawned_reconciles >= stuck_reconciles,
"It is impossible to have more keep-failing reconciles than spawned reconciles" "It is impossible to have less spawned reconciles than stuck reconciles"
); );
Self { Self {
spawned_reconciles, spawned_reconciles,
keep_failing_reconciles, stuck_reconciles,
has_delayed_reconciles, has_delayed_reconciles,
} }
} }
/// We can run optimizations only if we don't have any delayed reconciles and /// We can run optimizations only if we don't have any delayed reconciles and
/// all spawned reconciles are also keep-failing reconciles. /// all spawned reconciles are also stuck reconciles.
fn can_run_optimizations(&self) -> bool { fn can_run_optimizations(&self) -> bool {
!self.has_delayed_reconciles && self.spawned_reconciles == self.keep_failing_reconciles !self.has_delayed_reconciles && self.spawned_reconciles == self.stuck_reconciles
}
}
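A small sketch of the gating rule, using the constructor above with illustrative counts: optimizations may run only when nothing was delayed and every spawned reconcile is a stuck one:

    // 3 spawned, all 3 stuck, nothing delayed -> optimizations may run.
    assert!(ReconcileAllResult::new(3, 3, false).can_run_optimizations());
    // One non-stuck reconcile still in flight blocks optimizations.
    assert!(!ReconcileAllResult::new(3, 2, false).can_run_optimizations());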
enum TenantIdOrShardId {
TenantId(TenantId),
TenantShardId(TenantShardId),
}
impl TenantIdOrShardId {
fn tenant_id(&self) -> TenantId {
match self {
TenantIdOrShardId::TenantId(tenant_id) => *tenant_id,
TenantIdOrShardId::TenantShardId(tenant_shard_id) => tenant_shard_id.tenant_id,
}
}
fn matches(&self, tenant_shard_id: &TenantShardId) -> bool {
match self {
TenantIdOrShardId::TenantId(tenant_id) => tenant_shard_id.tenant_id == *tenant_id,
TenantIdOrShardId::TenantShardId(this_tenant_shard_id) => {
this_tenant_shard_id == tenant_shard_id
}
}
} }
} }
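Behaviourally (a sketch; `TenantShardId::unsharded` is the existing helper for a zero-shard-count ID, and `tenant_id` is illustrative): a plain tenant ID matches any shard of that tenant, while a concrete shard ID matches only itself:

    let shard0 = TenantShardId::unsharded(tenant_id);
    assert!(TenantIdOrShardId::TenantId(tenant_id).matches(&shard0));
    assert!(TenantIdOrShardId::TenantShardId(shard0).matches(&shard0));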
@@ -1503,7 +1505,6 @@ impl Service {
match result.result { match result.result {
Ok(()) => { Ok(()) => {
tenant.consecutive_errors_count = 0;
tenant.apply_observed_deltas(deltas); tenant.apply_observed_deltas(deltas);
tenant.waiter.advance(result.sequence); tenant.waiter.advance(result.sequence);
} }
@@ -1522,8 +1523,6 @@ impl Service {
} }
} }
tenant.consecutive_errors_count = tenant.consecutive_errors_count.saturating_add(1);
// Ordering: populate last_error before advancing error_seq, // Ordering: populate last_error before advancing error_seq,
// so that waiters will see the correct error after waiting. // so that waiters will see the correct error after waiting.
tenant.set_last_error(result.sequence, e); tenant.set_last_error(result.sequence, e);
@@ -1535,6 +1534,8 @@ impl Service {
} }
} }
tenant.consecutive_reconciles_count = tenant.consecutive_reconciles_count.saturating_add(1);
// If we just finished detaching all shards for a tenant, it might be time to drop it from memory. // If we just finished detaching all shards for a tenant, it might be time to drop it from memory.
if tenant.policy == PlacementPolicy::Detached { if tenant.policy == PlacementPolicy::Detached {
// We may only drop a tenant from memory while holding the exclusive lock on the tenant ID: this protects us // We may only drop a tenant from memory while holding the exclusive lock on the tenant ID: this protects us
@@ -4773,72 +4774,24 @@ impl Service {
Ok(()) Ok(())
} }
fn is_observed_consistent_with_intent( pub(crate) fn is_tenant_not_found_error(body: &str, tenant_id: TenantId) -> bool {
&self, body.contains(&format!("tenant {tenant_id}"))
shard: &TenantShard,
intent_node_id: NodeId,
) -> bool {
if let Some(location) = shard.observed.locations.get(&intent_node_id)
&& let Some(ref conf) = location.conf
&& (conf.mode == LocationConfigMode::AttachedSingle
|| conf.mode == LocationConfigMode::AttachedMulti)
{
true
} else {
false
}
}
fn collect_tenant_shards(
&self,
tenant_id: TenantId,
) -> Result<TenantShardAttachState, ApiError> {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
let mut by_node_id = HashMap::new();
// If the request got an unsharded tenant id, then apply
// the operation to all shards. Otherwise, apply it to a specific shard.
let shards_range = TenantShardId::tenant_range(tenant_id);
for (tenant_shard_id, shard) in locked.tenants.range(shards_range) {
if let Some(node_id) = shard.intent.get_attached() {
let node = locked
.nodes
.get(node_id)
.expect("Pageservers may not be deleted while referenced");
let consistent = self.is_observed_consistent_with_intent(shard, *node_id);
targets.push((*tenant_shard_id, *node_id, node.clone(), consistent));
by_node_id.insert(*node_id, (*tenant_shard_id, node.clone(), consistent));
}
}
Ok(TenantShardAttachState {
targets,
by_node_id,
})
} }
fn process_result_and_passthrough_errors<T>( fn process_result_and_passthrough_errors<T>(
&self, &self,
tenant_id: TenantId,
results: Vec<(Node, Result<T, mgmt_api::Error>)>, results: Vec<(Node, Result<T, mgmt_api::Error>)>,
attach_state: TenantShardAttachState,
) -> Result<Vec<(Node, T)>, ApiError> { ) -> Result<Vec<(Node, T)>, ApiError> {
let mut processed_results: Vec<(Node, T)> = Vec::with_capacity(results.len()); let mut processed_results: Vec<(Node, T)> = Vec::with_capacity(results.len());
debug_assert_eq!(results.len(), attach_state.targets.len());
for (node, res) in results { for (node, res) in results {
let is_consistent = attach_state
.by_node_id
.get(&node.get_id())
.map(|(_, _, consistent)| *consistent);
match res { match res {
Ok(res) => processed_results.push((node, res)), Ok(res) => processed_results.push((node, res)),
Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _)) Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, body))
if is_consistent == Some(false) => if Self::is_tenant_not_found_error(&body, tenant_id) =>
{ {
{ // This is expected if the attach is not finished yet. Return 503 so that the client can retry. // If the tenant is not found, we are still in the process of attaching the tenant.
// Return 503 so that the client can retry.
return Err(ApiError::ResourceUnavailable( return Err(ApiError::ResourceUnavailable(
format!( format!(
"Timeline is not attached to the pageserver {} yet, please retry", "Timeline is not attached to the pageserver {} yet, please retry",
@@ -4866,35 +4819,48 @@ impl Service {
) )
.await; .await;
let attach_state = self.collect_tenant_shards(tenant_id)?; self.tenant_remote_mutation(tenant_id, |locations| async move {
if locations.0.is_empty() {
let results = self return Err(ApiError::NotFound(
.tenant_for_shards_api( anyhow::anyhow!("Tenant not found").into(),
attach_state.for_api_call(), ));
|tenant_shard_id, client| async move {
client
.timeline_lease_lsn(tenant_shard_id, timeline_id, lsn)
.await
},
1,
1,
SHORT_RECONCILE_TIMEOUT,
&self.cancel,
)
.await;
let leases = self.process_result_and_passthrough_errors(results, attach_state)?;
let mut valid_until = None;
for (_, lease) in leases {
if let Some(ref mut valid_until) = valid_until {
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
} else {
valid_until = Some(lease.valid_until);
} }
}
Ok(LsnLease { let results = self
valid_until: valid_until.unwrap_or_else(SystemTime::now), .tenant_for_shards_api(
locations
.0
.iter()
.map(|(tenant_shard_id, ShardMutationLocations { latest, .. })| {
(*tenant_shard_id, latest.node.clone())
})
.collect(),
|tenant_shard_id, client| async move {
client
.timeline_lease_lsn(tenant_shard_id, timeline_id, lsn)
.await
},
1,
1,
SHORT_RECONCILE_TIMEOUT,
&self.cancel,
)
.await;
let leases = self.process_result_and_passthrough_errors(tenant_id, results)?;
let mut valid_until = None;
for (_, lease) in leases {
if let Some(ref mut valid_until) = valid_until {
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
} else {
valid_until = Some(lease.valid_until);
}
}
Ok(LsnLease {
valid_until: valid_until.unwrap_or_else(SystemTime::now),
})
}) })
.await?
} }
pub(crate) async fn tenant_timeline_download_heatmap_layers( pub(crate) async fn tenant_timeline_download_heatmap_layers(
@@ -5041,11 +5007,37 @@ impl Service {
/// - Looks up the shards and the nodes where they were most recently attached /// - Looks up the shards and the nodes where they were most recently attached
/// - Guarantees that after the inner function returns, the shards' generations haven't moved on: this /// - Guarantees that after the inner function returns, the shards' generations haven't moved on: this
/// ensures that the remote operation acted on the most recent generation, and is therefore durable. /// ensures that the remote operation acted on the most recent generation, and is therefore durable.
async fn tenant_remote_mutation<R, O, F>( pub(crate) async fn tenant_remote_mutation<R, O, F>(
&self, &self,
tenant_id: TenantId, tenant_id: TenantId,
op: O, op: O,
) -> Result<R, ApiError> ) -> Result<R, ApiError>
where
O: FnOnce(TenantMutationLocations) -> F,
F: std::future::Future<Output = R>,
{
self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantId(tenant_id), op)
.await
}
pub(crate) async fn tenant_shard_remote_mutation<R, O, F>(
&self,
tenant_shard_id: TenantShardId,
op: O,
) -> Result<R, ApiError>
where
O: FnOnce(TenantMutationLocations) -> F,
F: std::future::Future<Output = R>,
{
self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantShardId(tenant_shard_id), op)
.await
}
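For illustration, a hedged sketch of the closure shape these wrappers expect, inside a caller that returns `Result<_, ApiError>` (the closure body and return value are arbitrary; `locations.0` is the BTreeMap defined above):

    let attached_location_count = service
        .tenant_shard_remote_mutation(tenant_shard_id, |locations| async move {
            // Inspect the latest attached locations found for this shard.
            locations.0.len()
        })
        .await?;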
async fn tenant_remote_mutation_inner<R, O, F>(
&self,
tenant_id_or_shard_id: TenantIdOrShardId,
op: O,
) -> Result<R, ApiError>
where where
O: FnOnce(TenantMutationLocations) -> F, O: FnOnce(TenantMutationLocations) -> F,
F: std::future::Future<Output = R>, F: std::future::Future<Output = R>,
@@ -5057,7 +5049,13 @@ impl Service {
// run concurrently with reconciliations, and it is not guaranteed that the node we find here // run concurrently with reconciliations, and it is not guaranteed that the node we find here
// will still be the latest when we're done: we will check generations again at the end of // will still be the latest when we're done: we will check generations again at the end of
// this function to handle that. // this function to handle that.
let generations = self.persistence.tenant_generations(tenant_id).await?; let generations = self
.persistence
.tenant_generations(tenant_id_or_shard_id.tenant_id())
.await?
.into_iter()
.filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id))
.collect::<Vec<_>>();
if generations if generations
.iter() .iter()
@@ -5071,9 +5069,14 @@ impl Service {
// One or more shards has not been attached to a pageserver. Check if this is because it's configured // One or more shards has not been attached to a pageserver. Check if this is because it's configured
// to be detached (409: caller should give up), or because it's meant to be attached but isn't yet (503: caller should retry) // to be detached (409: caller should give up), or because it's meant to be attached but isn't yet (503: caller should retry)
let locked = self.inner.read().unwrap(); let locked = self.inner.read().unwrap();
for (shard_id, shard) in let tenant_shards = locked
locked.tenants.range(TenantShardId::tenant_range(tenant_id)) .tenants
{ .range(TenantShardId::tenant_range(
tenant_id_or_shard_id.tenant_id(),
))
.filter(|(shard_id, _)| tenant_id_or_shard_id.matches(shard_id))
.collect::<Vec<_>>();
for (shard_id, shard) in tenant_shards {
match shard.policy { match shard.policy {
PlacementPolicy::Attached(_) => { PlacementPolicy::Attached(_) => {
// This shard is meant to be attached: the caller is not wrong to try and // This shard is meant to be attached: the caller is not wrong to try and
@@ -5183,7 +5186,14 @@ impl Service {
// Post-check: are all the generations of all the shards the same as they were initially? This proves that // Post-check: are all the generations of all the shards the same as they were initially? This proves that
// our remote operation executed on the latest generation and is therefore persistent. // our remote operation executed on the latest generation and is therefore persistent.
{ {
let latest_generations = self.persistence.tenant_generations(tenant_id).await?; let latest_generations = self
.persistence
.tenant_generations(tenant_id_or_shard_id.tenant_id())
.await?
.into_iter()
.filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id))
.collect::<Vec<_>>();
if latest_generations if latest_generations
.into_iter() .into_iter()
.map( .map(
@@ -5317,7 +5327,7 @@ impl Service {
pub(crate) async fn tenant_shard0_node( pub(crate) async fn tenant_shard0_node(
&self, &self,
tenant_id: TenantId, tenant_id: TenantId,
) -> Result<(Node, TenantShardId, bool), ApiError> { ) -> Result<(Node, TenantShardId), ApiError> {
let tenant_shard_id = { let tenant_shard_id = {
let locked = self.inner.read().unwrap(); let locked = self.inner.read().unwrap();
let Some((tenant_shard_id, _shard)) = locked let Some((tenant_shard_id, _shard)) = locked
@@ -5335,7 +5345,7 @@ impl Service {
self.tenant_shard_node(tenant_shard_id) self.tenant_shard_node(tenant_shard_id)
.await .await
.map(|(node, consistent)| (node, tenant_shard_id, consistent)) .map(|node| (node, tenant_shard_id))
} }
/// When you need to send an HTTP request to the pageserver that holds a shard of a tenant, this /// When you need to send an HTTP request to the pageserver that holds a shard of a tenant, this
@@ -5345,7 +5355,7 @@ impl Service {
pub(crate) async fn tenant_shard_node( pub(crate) async fn tenant_shard_node(
&self, &self,
tenant_shard_id: TenantShardId, tenant_shard_id: TenantShardId,
) -> Result<(Node, bool), ApiError> { ) -> Result<Node, ApiError> {
// Look up in-memory state and maybe use the node from there. // Look up in-memory state and maybe use the node from there.
{ {
let locked = self.inner.read().unwrap(); let locked = self.inner.read().unwrap();
@@ -5375,8 +5385,7 @@ impl Service {
"Shard refers to nonexistent node" "Shard refers to nonexistent node"
))); )));
}; };
let consistent = self.is_observed_consistent_with_intent(shard, *intent_node_id); return Ok(node.clone());
return Ok((node.clone(), consistent));
} }
}; };
@@ -5411,7 +5420,7 @@ impl Service {
))); )));
}; };
// As a reconciliation is in flight, we do not have the observed state yet, and therefore we assume it is always inconsistent. // As a reconciliation is in flight, we do not have the observed state yet, and therefore we assume it is always inconsistent.
Ok((node.clone(), false)) Ok(node.clone())
} }
pub(crate) fn tenant_locate( pub(crate) fn tenant_locate(
@@ -7385,6 +7394,7 @@ impl Service {
self: &Arc<Self>, self: &Arc<Self>,
node_id: NodeId, node_id: NodeId,
policy_on_start: NodeSchedulingPolicy, policy_on_start: NodeSchedulingPolicy,
force: bool,
cancel: CancellationToken, cancel: CancellationToken,
) -> Result<(), OperationError> { ) -> Result<(), OperationError> {
let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal).build(); let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal).build();
@@ -7392,23 +7402,27 @@ impl Service {
let mut waiters: Vec<ReconcilerWaiter> = Vec::new(); let mut waiters: Vec<ReconcilerWaiter> = Vec::new();
let mut tid_iter = create_shared_shard_iterator(self.clone()); let mut tid_iter = create_shared_shard_iterator(self.clone());
let reset_node_policy_on_cancel = || async {
match self
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => OperationError::Cancelled,
Err(err) => {
OperationError::FinalizeError(
format!(
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
)
}
}
};
while !tid_iter.finished() { while !tid_iter.finished() {
if cancel.is_cancelled() { if cancel.is_cancelled() {
match self return Err(reset_node_policy_on_cancel().await);
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => return Err(OperationError::Cancelled),
Err(err) => {
return Err(OperationError::FinalizeError(
format!(
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
));
}
}
} }
operation_utils::validate_node_state( operation_utils::validate_node_state(
@@ -7477,8 +7491,18 @@ impl Service {
nodes, nodes,
reconciler_config, reconciler_config,
); );
if let Some(some) = waiter {
waiters.push(some); if force {
// Here we remove an existing observed location for the node we're removing, and it will
// not be re-added by a reconciler's completion because we filter out removed nodes in
// process_result.
//
// Note that we update the shard's observed state _after_ calling maybe_configured_reconcile_shard:
// that means any reconciles we spawned will know about the node we're deleting,
// enabling them to do live migrations if it's still online.
tenant_shard.observed.locations.remove(&node_id);
} else if let Some(waiter) = waiter {
waiters.push(waiter);
} }
} }
} }
@@ -7492,21 +7516,7 @@ impl Service {
while !waiters.is_empty() { while !waiters.is_empty() {
if cancel.is_cancelled() { if cancel.is_cancelled() {
match self return Err(reset_node_policy_on_cancel().await);
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => return Err(OperationError::Cancelled),
Err(err) => {
return Err(OperationError::FinalizeError(
format!(
"Failed to finalise drain cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
));
}
}
} }
tracing::info!("Awaiting {} pending delete reconciliations", waiters.len()); tracing::info!("Awaiting {} pending delete reconciliations", waiters.len());
@@ -7516,6 +7526,12 @@ impl Service {
.await; .await;
} }
let pf = pausable_failpoint!("delete-node-after-reconciles-spawned", &cancel);
if pf.is_err() {
// An error from pausable_failpoint indicates the cancel token was triggered.
return Err(reset_node_policy_on_cancel().await);
}
self.persistence self.persistence
.set_tombstone(node_id) .set_tombstone(node_id)
.await .await
@@ -8111,6 +8127,7 @@ impl Service {
pub(crate) async fn start_node_delete( pub(crate) async fn start_node_delete(
self: &Arc<Self>, self: &Arc<Self>,
node_id: NodeId, node_id: NodeId,
force: bool,
) -> Result<(), ApiError> { ) -> Result<(), ApiError> {
let (ongoing_op, node_policy, schedulable_nodes_count) = { let (ongoing_op, node_policy, schedulable_nodes_count) = {
let locked = self.inner.read().unwrap(); let locked = self.inner.read().unwrap();
@@ -8180,7 +8197,7 @@ impl Service {
tracing::info!("Delete background operation starting"); tracing::info!("Delete background operation starting");
let res = service let res = service
.delete_node(node_id, policy_on_start, cancel) .delete_node(node_id, policy_on_start, force, cancel)
.await; .await;
match res { match res {
Ok(()) => { Ok(()) => {
@@ -8632,7 +8649,7 @@ impl Service {
// This function is an efficient place to update lazy statistics, since we are walking // This function is an efficient place to update lazy statistics, since we are walking
// all tenants. // all tenants.
let mut pending_reconciles = 0; let mut pending_reconciles = 0;
let mut keep_failing_reconciles = 0; let mut stuck_reconciles = 0;
let mut az_violations = 0; let mut az_violations = 0;
// If we find any tenants to drop from memory, stash them to offload after // If we find any tenants to drop from memory, stash them to offload after
@@ -8668,30 +8685,32 @@ impl Service {
// Eventual consistency: if an earlier reconcile job failed, and the shard is still // Eventual consistency: if an earlier reconcile job failed, and the shard is still
// dirty, spawn another one // dirty, spawn another one
let consecutive_errors_count = shard.consecutive_errors_count;
if self if self
.maybe_reconcile_shard(shard, &pageservers, ReconcilerPriority::Normal) .maybe_reconcile_shard(shard, &pageservers, ReconcilerPriority::Normal)
.is_some() .is_some()
{ {
spawned_reconciles += 1; spawned_reconciles += 1;
// Count shards that are keep-failing. We still want to reconcile them if shard.consecutive_reconciles_count >= MAX_CONSECUTIVE_RECONCILES {
// to avoid a situation where a shard is stuck. // Count shards that are stuck, but we still want to reconcile them.
// But we don't want to consider them when deciding to run optimizations. // We don't want to consider them when deciding whether to run optimizations.
if consecutive_errors_count >= MAX_CONSECUTIVE_RECONCILIATION_ERRORS {
tracing::warn!( tracing::warn!(
tenant_id=%shard.tenant_shard_id.tenant_id, tenant_id=%shard.tenant_shard_id.tenant_id,
shard_id=%shard.tenant_shard_id.shard_slug(), shard_id=%shard.tenant_shard_id.shard_slug(),
"Shard reconciliation is keep-failing: {} errors", "Shard reconciliation is stuck: {} consecutive launches",
consecutive_errors_count shard.consecutive_reconciles_count
); );
keep_failing_reconciles += 1; stuck_reconciles += 1;
}
} else {
if shard.delayed_reconcile {
// Shard wanted to reconcile but for some reason couldn't.
pending_reconciles += 1;
} }
} else if shard.delayed_reconcile {
// Shard wanted to reconcile but for some reason couldn't.
pending_reconciles += 1;
}
// Reset the counter when we don't need to launch a reconcile.
shard.consecutive_reconciles_count = 0;
}
// If this tenant is detached, try dropping it from memory. This is usually done // If this tenant is detached, try dropping it from memory. This is usually done
// proactively in [`Self::process_results`], but we do it here to handle the edge // proactively in [`Self::process_results`], but we do it here to handle the edge
// case where a reconcile completes while someone else is holding an op lock for the tenant. // case where a reconcile completes while someone else is holding an op lock for the tenant.
@@ -8727,14 +8746,10 @@ impl Service {
metrics::METRICS_REGISTRY metrics::METRICS_REGISTRY
.metrics_group .metrics_group
.storage_controller_keep_failing_reconciles .storage_controller_stuck_reconciles
.set(keep_failing_reconciles as i64); .set(stuck_reconciles as i64);
ReconcileAllResult::new( ReconcileAllResult::new(spawned_reconciles, stuck_reconciles, has_delayed_reconciles)
spawned_reconciles,
keep_failing_reconciles,
has_delayed_reconciles,
)
} }
/// `optimize` in this context means identifying shards which have valid scheduled locations, but /// `optimize` in this context means identifying shards which have valid scheduled locations, but


@@ -131,14 +131,16 @@ pub(crate) struct TenantShard {
#[serde(serialize_with = "read_last_error")] #[serde(serialize_with = "read_last_error")]
pub(crate) last_error: std::sync::Arc<std::sync::Mutex<Option<Arc<ReconcileError>>>>, pub(crate) last_error: std::sync::Arc<std::sync::Mutex<Option<Arc<ReconcileError>>>>,
/// Number of consecutive reconciliation errors that have occurred for this shard. /// Number of consecutive [`crate::service::Service::reconcile_all`] iterations that have
/// scheduled a reconciliation for this shard.
/// ///
/// When this count reaches MAX_CONSECUTIVE_RECONCILIATION_ERRORS, the tenant shard /// If this reaches `MAX_CONSECUTIVE_RECONCILES`, the shard is considered "stuck" and will be
/// will be countered as keep-failing in `reconcile_all` calculations. This will lead to /// ignored when deciding whether optimizations can run. This includes both successful and failed
/// allowing optimizations to run even with some failing shards. /// reconciliations.
/// ///
/// The counter is reset to 0 after a successful reconciliation. /// Incremented in [`crate::service::Service::process_result`], and reset to 0 when
pub(crate) consecutive_errors_count: usize, /// [`crate::service::Service::reconcile_all`] determines no reconciliation is needed for this shard.
pub(crate) consecutive_reconciles_count: usize,
/// If we have a pending compute notification that for some reason we weren't able to send, /// If we have a pending compute notification that for some reason we weren't able to send,
/// set this to true. If this is set, calls to [`Self::get_reconcile_needed`] will return Yes /// set this to true. If this is set, calls to [`Self::get_reconcile_needed`] will return Yes
@@ -603,7 +605,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence(0))), waiter: Arc::new(SeqWait::new(Sequence(0))),
error_waiter: Arc::new(SeqWait::new(Sequence(0))), error_waiter: Arc::new(SeqWait::new(Sequence(0))),
last_error: Arc::default(), last_error: Arc::default(),
consecutive_errors_count: 0, consecutive_reconciles_count: 0,
pending_compute_notification: false, pending_compute_notification: false,
scheduling_policy: ShardSchedulingPolicy::default(), scheduling_policy: ShardSchedulingPolicy::default(),
preferred_node: None, preferred_node: None,
@@ -1609,7 +1611,13 @@ impl TenantShard {
// Update result counter // Update result counter
let outcome_label = match &result { let outcome_label = match &result {
Ok(_) => ReconcileOutcome::Success, Ok(_) => {
if reconciler.compute_notify_failure {
ReconcileOutcome::SuccessNoNotify
} else {
ReconcileOutcome::Success
}
}
Err(ReconcileError::Cancel) => ReconcileOutcome::Cancel, Err(ReconcileError::Cancel) => ReconcileOutcome::Cancel,
Err(_) => ReconcileOutcome::Error, Err(_) => ReconcileOutcome::Error,
}; };
@@ -1908,7 +1916,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence::initial())), waiter: Arc::new(SeqWait::new(Sequence::initial())),
error_waiter: Arc::new(SeqWait::new(Sequence::initial())), error_waiter: Arc::new(SeqWait::new(Sequence::initial())),
last_error: Arc::default(), last_error: Arc::default(),
consecutive_errors_count: 0, consecutive_reconciles_count: 0,
pending_compute_notification: false, pending_compute_notification: false,
delayed_reconcile: false, delayed_reconcile: false,
scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(), scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(),


@@ -2119,11 +2119,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN), headers=self.headers(TokenScope.ADMIN),
) )
def node_delete(self, node_id): def node_delete(self, node_id, force: bool = False):
log.info(f"node_delete({node_id})") log.info(f"node_delete({node_id})")
query = f"{self.api}/control/v1/node/{node_id}/delete"
if force:
query += "?force=true"
self.request( self.request(
"PUT", "PUT",
f"{self.api}/control/v1/node/{node_id}/delete", query,
headers=self.headers(TokenScope.ADMIN), headers=self.headers(TokenScope.ADMIN),
) )


@@ -847,7 +847,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
return res_json return res_json
def timeline_lsn_lease( def timeline_lsn_lease(
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn, **kwargs
): ):
data = { data = {
"lsn": str(lsn), "lsn": str(lsn),
@@ -857,6 +857,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
res = self.post( res = self.post(
f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/lsn_lease", f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/lsn_lease",
json=data, json=data,
**kwargs,
) )
self.verbose_error(res) self.verbose_error(res)
res_json = res.json() res_json = res.json()


@@ -187,19 +187,21 @@ def test_create_snapshot(
     env.pageserver.stop()
     env.storage_controller.stop()

-    # Directory `compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it
-    compatibility_snapshot_dir = (
+    # Directory `new_compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it
+    new_compatibility_snapshot_dir = (
         top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
     )
-    if compatibility_snapshot_dir.exists():
-        shutil.rmtree(compatibility_snapshot_dir)
+    if new_compatibility_snapshot_dir.exists():
+        shutil.rmtree(new_compatibility_snapshot_dir)
     shutil.copytree(
         test_output_dir,
-        compatibility_snapshot_dir,
+        new_compatibility_snapshot_dir,
         ignore=shutil.ignore_patterns("pg_dynshmem"),
     )
+    log.info(f"Copied new compatibility snapshot dir to: {new_compatibility_snapshot_dir}")

 # check_neon_works does recovery from WAL => the compatibility snapshot's WAL is old => will log this warning
 ingest_lag_log_line = ".*ingesting record with timestamp lagging more than wait_lsn_timeout.*"
@@ -218,6 +220,7 @@ def test_backward_compatibility(
     """
     Test that the new binaries can read old data
     """
+    log.info(f"Using snapshot dir at {compatibility_snapshot_dir}")
     neon_env_builder.num_safekeepers = 3
     env = neon_env_builder.from_repo_dir(compatibility_snapshot_dir / "repo")
     env.pageserver.allowed_errors.append(ingest_lag_log_line)
@@ -242,7 +245,6 @@ def test_forward_compatibility(
     test_output_dir: Path,
     top_output_dir: Path,
     pg_version: PgVersion,
-    compatibility_snapshot_dir: Path,
     compute_reconfigure_listener: ComputeReconfigure,
 ):
     """
@@ -266,8 +268,14 @@ def test_forward_compatibility(
     neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath
     neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir

+    # Note that we are testing with new data, so we should use `new_compatibility_snapshot_dir`, which is created by test_create_snapshot.
+    new_compatibility_snapshot_dir = (
+        top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
+    )
+    log.info(f"Using snapshot dir at {new_compatibility_snapshot_dir}")
     env = neon_env_builder.from_repo_dir(
-        compatibility_snapshot_dir / "repo",
+        new_compatibility_snapshot_dir / "repo",
     )
     # there may be an arbitrary number of unrelated tests run between create_snapshot and here
     env.pageserver.allowed_errors.append(ingest_lag_log_line)
@@ -296,7 +304,7 @@ def test_forward_compatibility(
     check_neon_works(
         env,
         test_output_dir=test_output_dir,
-        sql_dump_path=compatibility_snapshot_dir / "dump.sql",
+        sql_dump_path=new_compatibility_snapshot_dir / "dump.sql",
         repo_dir=env.repo_dir,
     )
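
For orientation, the flow across the three tests as implied by the comments above (a sketch; the CI job wiring is an assumption):

    # 1. test_create_snapshot (new binaries)        -> writes new_compatibility_snapshot_dir
    # 2. test_backward_compatibility (new binaries) -> reads the old snapshot fixture fetched from S3
    # 3. test_forward_compatibility (old binaries)  -> reads new_compatibility_snapshot_dir from step 1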

View File

@@ -246,9 +246,9 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder):
     system_memory = psutil.virtual_memory().total

-    # The smallest total size limit we can configure is 1/1024th of the system memory (e.g. 128MB on
-    # a system with 128GB of RAM). We will then write enough data to violate this limit.
-    max_dirty_data = 128 * 1024 * 1024
+    # The smallest total size limit we can configure is 1/1024th of the system memory (e.g. 256MB on
+    # a system with 256GB of RAM). We will then write enough data to violate this limit.
+    max_dirty_data = 256 * 1024 * 1024
     ephemeral_bytes_per_memory_kb = (max_dirty_data * 1024) // system_memory
     assert ephemeral_bytes_per_memory_kb > 0
@@ -272,7 +272,7 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder):
     timeline_count = 10
     # This is about 2MiB of data per timeline
-    entries_per_timeline = 100_000
+    entries_per_timeline = 200_000
     last_flush_lsns = asyncio.run(workload(env, tenant_conf, timeline_count, entries_per_timeline))
     wait_until_pageserver_is_caught_up(env, last_flush_lsns)
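
To see why 256 MiB is the floor on a 256 GiB machine: the setting is an integer number of ephemeral bytes per KiB of RAM, so the smallest nonzero value is 1/1024th of memory. A worked check of the arithmetic above:

    system_memory = 256 * 1024**3       # 256 GiB, hypothetical host
    max_dirty_data = 256 * 1024 * 1024  # 256 MiB
    # bytes of ephemeral data allowed per KiB of RAM; must round to >= 1
    assert (max_dirty_data * 1024) // system_memory == 1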

View File

@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
 import fixtures.utils
 import pytest
 from fixtures.auth_tokens import TokenScope
-from fixtures.common_types import TenantId, TenantShardId, TimelineId
+from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
 from fixtures.log_helper import log
 from fixtures.neon_fixtures import (
     DEFAULT_AZ_ID,
@@ -47,6 +47,7 @@ from fixtures.utils import (
     wait_until,
 )
 from fixtures.workload import Workload
+from requests.adapters import HTTPAdapter
 from urllib3 import Retry
 from werkzeug.wrappers.response import Response
@@ -72,6 +73,12 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids):
     return counts

+class DeletionAPIKind(Enum):
+    OLD = "old"
+    FORCE = "force"
+    GRACEFUL = "graceful"
+

 @pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
 def test_storage_controller_smoke(
     neon_env_builder: NeonEnvBuilder, compute_reconfigure_listener: ComputeReconfigure, combination
@@ -990,7 +997,7 @@ def test_storage_controller_compute_hook_retry(
 @run_only_on_default_postgres("postgres behavior is not relevant")
-def test_storage_controller_compute_hook_keep_failing(
+def test_storage_controller_compute_hook_stuck_reconciles(
     httpserver: HTTPServer,
     neon_env_builder: NeonEnvBuilder,
     httpserver_listen_address: ListenAddress,
@@ -1040,7 +1047,7 @@ def test_storage_controller_compute_hook_keep_failing(
     env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG)
     env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS)
     env.storage_controller.allowed_errors.append(".*Keeping extra secondaries.*")
-    env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*")
+    env.storage_controller.allowed_errors.append(".*Shard reconciliation is stuck.*")
     env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"})

     # Migrate all allowed tenant shards to the first alive pageserver
@@ -1055,7 +1062,7 @@ def test_storage_controller_compute_hook_keep_failing(
     # Make some reconcile_all calls to trigger optimizations
     # RECONCILE_COUNT must be greater than storcon's MAX_CONSECUTIVE_RECONCILIATION_ERRORS
-    RECONCILE_COUNT = 12
+    RECONCILE_COUNT = 20
     for i in range(RECONCILE_COUNT):
         try:
             n = env.storage_controller.reconcile_all()
@@ -1068,6 +1075,8 @@ def test_storage_controller_compute_hook_keep_failing(
     assert banned_descr["shards"][0]["is_pending_compute_notification"] is True

     time.sleep(2)
+    env.storage_controller.assert_log_contains(".*Shard reconciliation is stuck.*")
+
     # Check that the allowed tenant shards are optimized due to affinity rules
     locations = alive_pageservers[0].http_client().tenant_list_locations()["tenant_shards"]
     not_optimized_shard_count = 0
@@ -2572,9 +2581,11 @@ def test_background_operation_cancellation(neon_env_builder: NeonEnvBuilder):
 @pytest.mark.parametrize("while_offline", [True, False])
+@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE])
 def test_storage_controller_node_deletion(
     neon_env_builder: NeonEnvBuilder,
     while_offline: bool,
+    deletion_api: DeletionAPIKind,
 ):
     """
     Test that deleting a node works & properly reschedules everything that was on the node.
@@ -2598,6 +2609,8 @@ def test_storage_controller_node_deletion(
     assert env.storage_controller.reconcile_all() == 0

     victim = env.pageservers[-1]
+    if deletion_api == DeletionAPIKind.FORCE and not while_offline:
+        victim.allowed_errors.append(".*request was dropped before completing.*")

     # The procedure a human would follow is:
     #  1. Mark pageserver scheduling=pause
@@ -2621,7 +2634,12 @@ def test_storage_controller_node_deletion(
         wait_until(assert_shards_migrated)

     log.info(f"Deleting pageserver {victim.id}")
-    env.storage_controller.node_delete_old(victim.id)
+    if deletion_api == DeletionAPIKind.FORCE:
+        env.storage_controller.node_delete(victim.id, force=True)
+    elif deletion_api == DeletionAPIKind.OLD:
+        env.storage_controller.node_delete_old(victim.id)
+    else:
+        raise AssertionError(f"Invalid deletion API: {deletion_api}")

     if not while_offline:
@@ -2634,7 +2652,15 @@ def test_storage_controller_node_deletion(
         wait_until(assert_victim_evacuated)

     # The node should be gone from the list API
-    assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
+    def assert_node_is_gone():
+        assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
+
+    if deletion_api == DeletionAPIKind.FORCE:
+        wait_until(assert_node_is_gone)
+    elif deletion_api == DeletionAPIKind.OLD:
+        assert_node_is_gone()
+    else:
+        raise AssertionError(f"Invalid deletion API: {deletion_api}")

     # No tenants should refer to the node in their intent
     for tenant_id in tenant_ids:
@@ -2656,7 +2682,11 @@ def test_storage_controller_node_deletion(
     env.storage_controller.consistency_check()

-def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBuilder):
+@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL])
+def test_storage_controller_node_delete_cancellation(
+    neon_env_builder: NeonEnvBuilder,
+    deletion_api: DeletionAPIKind,
+):
     neon_env_builder.num_pageservers = 3
     neon_env_builder.num_azs = 3
     env = neon_env_builder.init_configs()
@@ -2680,12 +2710,16 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu
     assert len(nodes) == 3

     env.storage_controller.configure_failpoints(("sleepy-delete-loop", "return(10000)"))
+    env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "pause"))

     ps_id_to_delete = env.pageservers[0].id
     env.storage_controller.warm_up_all_secondaries()
+    assert deletion_api in [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL]
+    force = deletion_api == DeletionAPIKind.FORCE
+
     env.storage_controller.retryable_node_operation(
-        lambda ps_id: env.storage_controller.node_delete(ps_id),
+        lambda ps_id: env.storage_controller.node_delete(ps_id, force),
         ps_id_to_delete,
         max_attempts=3,
         backoff=2,
@@ -2701,6 +2735,8 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu
     env.storage_controller.cancel_node_delete(ps_id_to_delete)

+    env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "off"))
+
     env.storage_controller.poll_node_status(
         ps_id_to_delete,
         PageserverAvailability.ACTIVE,
@@ -3252,7 +3288,10 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
     wait_until(reconfigure_node_again)

-def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
+@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE])
+def test_ps_unavailable_after_delete(
+    neon_env_builder: NeonEnvBuilder, deletion_api: DeletionAPIKind
+):
     neon_env_builder.num_pageservers = 3
     env = neon_env_builder.init_start()
@@ -3265,10 +3304,16 @@ def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
     assert_nodes_count(3)

     ps = env.pageservers[0]
-    env.storage_controller.node_delete_old(ps.id)

-    # After deletion, the node count must be reduced
-    assert_nodes_count(2)
+    if deletion_api == DeletionAPIKind.FORCE:
+        ps.allowed_errors.append(".*request was dropped before completing.*")
+        env.storage_controller.node_delete(ps.id, force=True)
+        wait_until(lambda: assert_nodes_count(2))
+    elif deletion_api == DeletionAPIKind.OLD:
+        env.storage_controller.node_delete_old(ps.id)
+        assert_nodes_count(2)
+    else:
+        raise AssertionError(f"Invalid deletion API: {deletion_api}")

     # Running pageserver CLI init in a separate thread
     with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
@@ -4814,3 +4859,103 @@ def test_storage_controller_migrate_with_pageserver_restart(
         "shards": [{"node_id": int(secondary.id), "shard_number": 0}],
         "preferred_az": DEFAULT_AZ_ID,
     }
+
+
+@run_only_on_default_postgres("PG version is not important for this test")
+def test_storage_controller_forward_404(neon_env_builder: NeonEnvBuilder):
+    """
+    Ensures that the storage controller correctly forwards 404s and converts some of them
+    into 503s before forwarding to the client.
+    """
+    neon_env_builder.num_pageservers = 2
+    neon_env_builder.num_azs = 2
+    env = neon_env_builder.init_start()
+
+    env.storage_controller.allowed_errors.append(".*Reconcile error.*")
+    env.storage_controller.allowed_errors.append(".*Timed out.*")
+
+    env.storage_controller.tenant_policy_update(env.initial_tenant, {"placement": {"Attached": 1}})
+    env.storage_controller.reconcile_until_idle()
+
+    # 404s on tenants and timelines are forwarded as-is when reconciler is not running.
+
+    # Access a non-existing timeline -> 404
+    with pytest.raises(PageserverApiException) as e:
+        env.storage_controller.pageserver_api().timeline_detail(
+            env.initial_tenant, TimelineId.generate()
+        )
+    assert e.value.status_code == 404
+    with pytest.raises(PageserverApiException) as e:
+        env.storage_controller.pageserver_api().timeline_lsn_lease(
+            env.initial_tenant, TimelineId.generate(), Lsn(0)
+        )
+    assert e.value.status_code == 404
+
+    # Access a non-existing tenant when reconciler is not running -> 404
+    with pytest.raises(PageserverApiException) as e:
+        env.storage_controller.pageserver_api().timeline_detail(
+            TenantId.generate(), env.initial_timeline
+        )
+    assert e.value.status_code == 404
+    with pytest.raises(PageserverApiException) as e:
+        env.storage_controller.pageserver_api().timeline_lsn_lease(
+            TenantId.generate(), env.initial_timeline, Lsn(0)
+        )
+    assert e.value.status_code == 404
+
+    # Normal requests should succeed
+    detail = env.storage_controller.pageserver_api().timeline_detail(
+        env.initial_tenant, env.initial_timeline
+    )
+    last_record_lsn = Lsn(detail["last_record_lsn"])
+    env.storage_controller.pageserver_api().timeline_lsn_lease(
+        env.initial_tenant, env.initial_timeline, last_record_lsn
+    )
+
+    # Get into a situation where the intent state is not the same as the observed state.
+    describe = env.storage_controller.tenant_describe(env.initial_tenant)["shards"][0]
+    current_primary = describe["node_attached"]
+    current_secondary = describe["node_secondary"][0]
+    assert current_primary != current_secondary
+
+    # Pause the reconciler so that the generation number won't be updated.
+    env.storage_controller.configure_failpoints(
+        ("reconciler-live-migrate-post-generation-inc", "pause")
+    )
+
+    # Do the migration in another thread; the request will be dropped as we don't wait.
+    shard_zero = TenantShardId(env.initial_tenant, 0, 0)
+    concurrent.futures.ThreadPoolExecutor(max_workers=1).submit(
+        env.storage_controller.tenant_shard_migrate,
+        shard_zero,
+        current_secondary,
+        StorageControllerMigrationConfig(override_scheduler=True),
+    )
+
+    # Not the best way to do this, we should wait until the migration gets started.
+    time.sleep(1)
+
+    placement = env.storage_controller.get_tenants_placement()[str(shard_zero)]
+    assert placement["observed"] != placement["intent"]
+    assert placement["observed"]["attached"] == current_primary
+    assert placement["intent"]["attached"] == current_secondary
+
+    # Now we issue requests that would cause 404 again
+    retry_strategy = Retry(total=0)
+    adapter = HTTPAdapter(max_retries=retry_strategy)
+    no_retry_api = env.storage_controller.pageserver_api()
+    no_retry_api.mount("http://", adapter)
+    no_retry_api.mount("https://", adapter)
+
+    # As intent state != observed state, tenant not found error should return 503,
+    # so that the client can retry once we've successfully migrated.
+    with pytest.raises(PageserverApiException) as e:
+        no_retry_api.timeline_detail(env.initial_tenant, TimelineId.generate())
+    assert e.value.status_code == 503, f"unexpected status code and error: {e.value}"
+    with pytest.raises(PageserverApiException) as e:
+        no_retry_api.timeline_lsn_lease(env.initial_tenant, TimelineId.generate(), Lsn(0))
+    assert e.value.status_code == 503, f"unexpected status code and error: {e.value}"
+
+    # Unblock reconcile operations
+    env.storage_controller.configure_failpoints(
+        ("reconciler-live-migrate-post-generation-inc", "off")
+    )
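
Since the controller now maps "tenant not found while a migration is in flight" to 503, a client that wants to ride out migrations can simply retry on that status; a sketch using the same requests/urllib3 machinery as the test above (the retry counts are assumptions):

    retry = Retry(total=5, backoff_factor=0.5, status_forcelist=[503], allowed_methods=["GET", "POST"])
    api = env.storage_controller.pageserver_api()
    api.mount("http://", HTTPAdapter(max_retries=retry))
    # 404 still fails fast; 503 is retried until the migration completes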

View File

@@ -107,7 +107,6 @@ tracing-core = { version = "0.1" }
 tracing-log = { version = "0.2" }
 tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
 url = { version = "2", features = ["serde"] }
-uuid = { version = "1", features = ["serde", "v4", "v7"] }
 zeroize = { version = "1", features = ["derive", "serde"] }
 zstd = { version = "0.13" }
 zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }