diff --git a/Cargo.lock b/Cargo.lock index ca3c351555..9cb48b51be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1951,6 +1951,7 @@ dependencies = [ "diesel_derives", "itoa", "serde_json", + "uuid", ] [[package]] @@ -3962,7 +3963,6 @@ version = "0.1.0" dependencies = [ "ahash", "criterion", - "foldhash", "hashbrown 0.15.4", "libc", "lock_api", @@ -5509,7 +5509,7 @@ dependencies = [ "reqwest-tracing", "rsa", "rstest", - "rustc-hash 1.1.0", + "rustc-hash 2.1.1", "rustls 0.23.27", "rustls-native-certs 0.8.0", "rustls-pemfile 2.1.1", @@ -7138,6 +7138,7 @@ dependencies = [ "tokio-util", "tracing", "utils", + "uuid", "workspace_hack", ] @@ -8444,6 +8445,7 @@ dependencies = [ "tracing-error", "tracing-subscriber", "tracing-utils", + "uuid", "walkdir", ] @@ -9063,7 +9065,6 @@ dependencies = [ "tracing-log", "tracing-subscriber", "url", - "uuid", "zeroize", "zstd", "zstd-safe", diff --git a/Cargo.toml b/Cargo.toml index b622892392..76a1a57aa9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,6 +132,7 @@ jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] } jsonwebtoken = "9" lasso = "0.7" libc = "0.2" +lock_api = "0.4.13" md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } measured-process = { version = "0.0.22" } @@ -168,7 +169,7 @@ reqwest-middleware = "0.4" reqwest-retry = "0.7" routerify = "3" rpds = "0.13" -rustc-hash = "1.1.0" +rustc-hash = "2.1.1" rustls = { version = "0.23.16", default-features = false } rustls-pemfile = "2" rustls-pki-types = "1.11" diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 509551897f..d7ec37cc0a 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -1,4 +1,4 @@ -use anyhow::{Context, Result, anyhow}; +use anyhow::{Context, Result}; use chrono::{DateTime, Utc}; use compute_api::privilege::Privilege; use compute_api::responses::{ @@ -7,7 +7,7 @@ use compute_api::responses::{ }; use compute_api::spec::{ ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverConnectionInfo, - PageserverShardConnectionInfo, PgIdent, + PageserverProtocol, PageserverShardConnectionInfo, PageserverShardInfo, PgIdent, }; use futures::StreamExt; use futures::future::join_all; @@ -281,53 +281,114 @@ impl ParsedSpec { } } -fn extract_pageserver_conninfo_from_guc( - pageserver_connstring_guc: &str, -) -> PageserverConnectionInfo { - PageserverConnectionInfo { - shards: pageserver_connstring_guc - .split(',') - .enumerate() - .map(|(i, connstr)| { - ( - i as u32, - PageserverShardConnectionInfo { - libpq_url: Some(connstr.to_string()), - grpc_url: None, - }, - ) +/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings. +/// +/// This is used for backwards-compatilibity, to parse the legacye `pageserver_connstr` +/// field in the compute spec, or the 'neon.pageserver_connstring' GUC. Nowadays, the +/// 'pageserver_connection_info' field should be used instead. +fn extract_pageserver_conninfo_from_connstr( + connstr: &str, + stripe_size: Option, +) -> Result { + let shard_infos: Vec<_> = connstr + .split(',') + .map(|connstr| PageserverShardInfo { + pageservers: vec![PageserverShardConnectionInfo { + id: None, + libpq_url: Some(connstr.to_string()), + grpc_url: None, + }], + }) + .collect(); + + match shard_infos.len() { + 0 => anyhow::bail!("empty connection string"), + 1 => { + // We assume that if there's only connection string, it means "unsharded", + // rather than a sharded system with just a single shard. The latter is + // possible in principle, but we never do it. + let shard_count = ShardCount::unsharded(); + let only_shard = shard_infos.first().unwrap().clone(); + let shards = vec![(ShardIndex::unsharded(), only_shard)]; + Ok(PageserverConnectionInfo { + shard_count, + stripe_size: None, + shards: shards.into_iter().collect(), + prefer_protocol: PageserverProtocol::Libpq, }) - .collect(), - prefer_grpc: false, + } + n => { + if stripe_size.is_none() { + anyhow::bail!("{n} shards but no stripe_size"); + } + let shard_count = ShardCount(n.try_into()?); + let shards = shard_infos + .into_iter() + .enumerate() + .map(|(idx, shard_info)| { + ( + ShardIndex { + shard_count, + shard_number: ShardNumber( + idx.try_into().expect("shard number fits in u8"), + ), + }, + shard_info, + ) + }) + .collect(); + Ok(PageserverConnectionInfo { + shard_count, + stripe_size, + shards, + prefer_protocol: PageserverProtocol::Libpq, + }) + } } } impl TryFrom for ParsedSpec { - type Error = String; - fn try_from(spec: ComputeSpec) -> Result { + type Error = anyhow::Error; + fn try_from(spec: ComputeSpec) -> Result { // Extract the options from the spec file that are needed to connect to // the storage system. // - // For backwards-compatibility, the top-level fields in the spec file - // may be empty. In that case, we need to dig them from the GUCs in the - // cluster.settings field. - let pageserver_conninfo = match &spec.pageserver_connection_info { - Some(x) => x.clone(), - None => { - if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") { - extract_pageserver_conninfo_from_guc(&guc) - } else { - return Err("pageserver connstr should be provided".to_string()); - } + // In compute specs generated by old control plane versions, the spec file might + // be missing the `pageserver_connection_info` field. In that case, we need to dig + // the pageserver connection info from the `pageserver_connstr` field instead, or + // if that's missing too, from the GUC in the cluster.settings field. + let mut pageserver_conninfo = spec.pageserver_connection_info.clone(); + if pageserver_conninfo.is_none() { + if let Some(pageserver_connstr_field) = &spec.pageserver_connstring { + pageserver_conninfo = Some(extract_pageserver_conninfo_from_connstr( + pageserver_connstr_field, + spec.shard_stripe_size, + )?); } - }; + } + if pageserver_conninfo.is_none() { + if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") { + let stripe_size = if let Some(guc) = spec.cluster.settings.find("neon.stripe_size") + { + Some(u32::from_str(&guc)?) + } else { + None + }; + pageserver_conninfo = + Some(extract_pageserver_conninfo_from_connstr(&guc, stripe_size)?); + } + } + let pageserver_conninfo = pageserver_conninfo.ok_or(anyhow::anyhow!( + "pageserver connection information should be provided" + ))?; + // Similarly for safekeeper connection strings let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() { if matches!(spec.mode, ComputeMode::Primary) { spec.cluster .settings .find("neon.safekeepers") - .ok_or("safekeeper connstrings should be provided")? + .ok_or(anyhow::anyhow!("safekeeper connstrings should be provided"))? .split(',') .map(|str| str.to_string()) .collect() @@ -342,22 +403,22 @@ impl TryFrom for ParsedSpec { let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id { tenant_id } else { - spec.cluster + let guc = spec + .cluster .settings .find("neon.tenant_id") - .ok_or("tenant id should be provided") - .map(|s| TenantId::from_str(&s))? - .or(Err("invalid tenant id"))? + .ok_or(anyhow::anyhow!("tenant id should be provided"))?; + TenantId::from_str(&guc).context("invalid tenant id")? }; let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id { timeline_id } else { - spec.cluster + let guc = spec + .cluster .settings .find("neon.timeline_id") - .ok_or("timeline id should be provided") - .map(|s| TimelineId::from_str(&s))? - .or(Err("invalid timeline id"))? + .ok_or(anyhow::anyhow!("timeline id should be provided"))?; + TimelineId::from_str(&guc).context(anyhow::anyhow!("invalid timeline id"))? }; let endpoint_storage_addr: Option = spec @@ -381,7 +442,7 @@ impl TryFrom for ParsedSpec { }; // Now check validity of the parsed specification - res.validate()?; + res.validate().map_err(anyhow::Error::msg)?; Ok(res) } } @@ -461,7 +522,7 @@ impl ComputeNode { let mut new_state = ComputeState::new(); if let Some(spec) = config.spec { - let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow!(msg))?; + let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?; new_state.pspec = Some(pspec); } @@ -1069,10 +1130,9 @@ impl ComputeNode { let spec = compute_state.pspec.as_ref().expect("spec must be set"); let started = Instant::now(); - let (connected, size) = if spec.pageserver_conninfo.prefer_grpc { - self.try_get_basebackup_grpc(spec, lsn)? - } else { - self.try_get_basebackup_libpq(spec, lsn)? + let (connected, size) = match spec.pageserver_conninfo.prefer_protocol { + PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?, + PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?, }; self.fix_zenith_signal_neon_signal()?; @@ -1110,24 +1170,32 @@ impl ComputeNode { /// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when /// the connection was established, and the (compressed) size of the basebackup. fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> { + let shard0_index = ShardIndex { + shard_number: ShardNumber(0), + shard_count: spec.pageserver_conninfo.shard_count, + }; let shard0 = spec .pageserver_conninfo .shards - .get(&0) - .expect("shard 0 connection info missing"); - let shard0_url = shard0.grpc_url.clone().expect("no grpc_url for shard 0"); - - let shard_index = match spec.pageserver_conninfo.shards.len() as u8 { - 0 | 1 => ShardIndex::unsharded(), - count => ShardIndex::new(ShardNumber(0), ShardCount(count)), - }; + .get(&shard0_index) + .ok_or_else(|| { + anyhow::anyhow!("shard connection info missing for shard {}", shard0_index) + })?; + let pageserver = shard0 + .pageservers + .first() + .expect("must have at least one pageserver"); + let shard0_url = pageserver + .grpc_url + .clone() + .expect("no grpc_url for shard 0"); let (reader, connected) = tokio::runtime::Handle::current().block_on(async move { let mut client = page_api::Client::connect( shard0_url, spec.tenant_id, spec.timeline_id, - shard_index, + shard0_index, spec.storage_auth_token.clone(), None, // NB: base backups use payload compression ) @@ -1159,12 +1227,25 @@ impl ComputeNode { /// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp /// when the connection was established, and the (compressed) size of the basebackup. fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> { + let shard0_index = ShardIndex { + shard_number: ShardNumber(0), + shard_count: spec.pageserver_conninfo.shard_count, + }; let shard0 = spec .pageserver_conninfo .shards - .get(&0) - .expect("shard 0 connection info missing"); - let shard0_connstr = shard0.libpq_url.clone().expect("no libpq_url for shard 0"); + .get(&shard0_index) + .ok_or_else(|| { + anyhow::anyhow!("shard connection info missing for shard {}", shard0_index) + })?; + let pageserver = shard0 + .pageservers + .first() + .expect("must have at least one pageserver"); + let shard0_connstr = pageserver + .libpq_url + .clone() + .expect("no libpq_url for shard 0"); let mut config = postgres::Config::from_str(&shard0_connstr)?; // Use the storage auth token from the config file, if given. @@ -2128,7 +2209,7 @@ LIMIT 100", self.params .remote_ext_base_url .as_ref() - .ok_or(DownloadError::BadInput(anyhow!( + .ok_or(DownloadError::BadInput(anyhow::anyhow!( "Remote extensions storage is not configured", )))?; @@ -2324,7 +2405,7 @@ LIMIT 100", let remote_extensions = spec .remote_extensions .as_ref() - .ok_or(anyhow!("Remote extensions are not configured"))?; + .ok_or(anyhow::anyhow!("Remote extensions are not configured"))?; info!("parse shared_preload_libraries from spec.cluster.settings"); let mut libs_vec = Vec::new(); @@ -2472,14 +2553,31 @@ LIMIT 100", pub fn spawn_lfc_offload_task(self: &Arc, interval: Duration) { self.terminate_lfc_offload_task(); let secs = interval.as_secs(); - info!("spawning lfc offload worker with {secs}s interval"); let this = self.clone(); + + info!("spawning LFC offload worker with {secs}s interval"); let handle = spawn(async move { let mut interval = time::interval(interval); interval.tick().await; // returns immediately loop { interval.tick().await; - this.offload_lfc_async().await; + + let prewarm_state = this.state.lock().unwrap().lfc_prewarm_state.clone(); + // Do not offload LFC state if we are currently prewarming or any issue occurred. + // If we'd do that, we might override the LFC state in endpoint storage with some + // incomplete state. Imagine a situation: + // 1. Endpoint started with `autoprewarm: true` + // 2. While prewarming is not completed, we upload the new incomplete state + // 3. Compute gets interrupted and restarts + // 4. We start again and try to prewarm with the state from 2. instead of the previous complete state + if matches!( + prewarm_state, + LfcPrewarmState::Completed + | LfcPrewarmState::NotPrewarmed + | LfcPrewarmState::Skipped + ) { + this.offload_lfc_async().await; + } } }); *self.lfc_offload_task.lock().unwrap() = Some(handle); @@ -2631,7 +2729,10 @@ mod tests { match ParsedSpec::try_from(spec.clone()) { Ok(_p) => panic!("Failed to detect duplicate entry"), - Err(e) => assert!(e.starts_with("duplicate entry in safekeeper_connstrings:")), + Err(e) => assert!( + e.to_string() + .starts_with("duplicate entry in safekeeper_connstrings:") + ), }; } } diff --git a/compute_tools/src/compute_prewarm.rs b/compute_tools/src/compute_prewarm.rs index d014a5bb72..07b4a596cc 100644 --- a/compute_tools/src/compute_prewarm.rs +++ b/compute_tools/src/compute_prewarm.rs @@ -89,7 +89,7 @@ impl ComputeNode { self.state.lock().unwrap().lfc_offload_state.clone() } - /// If there is a prewarm request ongoing, return false, true otherwise + /// If there is a prewarm request ongoing, return `false`, `true` otherwise. pub fn prewarm_lfc(self: &Arc, from_endpoint: Option) -> bool { { let state = &mut self.state.lock().unwrap().lfc_prewarm_state; @@ -101,15 +101,25 @@ impl ComputeNode { let cloned = self.clone(); spawn(async move { - let Err(err) = cloned.prewarm_impl(from_endpoint).await else { - cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed; - return; - }; - crate::metrics::LFC_PREWARM_ERRORS.inc(); - error!(%err, "prewarming lfc"); - cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed { - error: err.to_string(), + let state = match cloned.prewarm_impl(from_endpoint).await { + Ok(true) => LfcPrewarmState::Completed, + Ok(false) => { + info!( + "skipping LFC prewarm because LFC state is not found in endpoint storage" + ); + LfcPrewarmState::Skipped + } + Err(err) => { + crate::metrics::LFC_PREWARM_ERRORS.inc(); + error!(%err, "could not prewarm LFC"); + + LfcPrewarmState::Failed { + error: err.to_string(), + } + } }; + + cloned.state.lock().unwrap().lfc_prewarm_state = state; }); true } @@ -120,15 +130,21 @@ impl ComputeNode { EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint) } - async fn prewarm_impl(&self, from_endpoint: Option) -> Result<()> { + /// Request LFC state from endpoint storage and load corresponding pages into Postgres. + /// Returns a result with `false` if the LFC state is not found in endpoint storage. + async fn prewarm_impl(&self, from_endpoint: Option) -> Result { let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?; - info!(%url, "requesting LFC state from endpoint storage"); + info!(%url, "requesting LFC state from endpoint storage"); let request = Client::new().get(&url).bearer_auth(token); let res = request.send().await.context("querying endpoint storage")?; let status = res.status(); - if status != StatusCode::OK { - bail!("{status} querying endpoint storage") + match status { + StatusCode::OK => (), + StatusCode::NOT_FOUND => { + return Ok(false); + } + _ => bail!("{status} querying endpoint storage"), } let mut uncompressed = Vec::new(); @@ -141,7 +157,8 @@ impl ComputeNode { .await .context("decoding LFC state")?; let uncompressed_len = uncompressed.len(); - info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres"); + + info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres"); ComputeNode::get_maintenance_client(&self.tokio_conn_conf) .await @@ -149,7 +166,9 @@ impl ComputeNode { .query_one("select neon.prewarm_local_cache($1)", &[&uncompressed]) .await .context("loading LFC state into postgres") - .map(|_| ()) + .map(|_| ())?; + + Ok(true) } /// If offload request is ongoing, return false, true otherwise @@ -177,12 +196,14 @@ impl ComputeNode { async fn offload_lfc_with_state_update(&self) { crate::metrics::LFC_OFFLOADS.inc(); + let Err(err) = self.offload_lfc_impl().await else { self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed; return; }; + crate::metrics::LFC_OFFLOAD_ERRORS.inc(); - error!(%err, "offloading lfc"); + error!(%err, "could not offload LFC state to endpoint storage"); self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed { error: err.to_string(), }; @@ -190,7 +211,7 @@ impl ComputeNode { async fn offload_lfc_impl(&self) -> Result<()> { let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?; - info!(%url, "requesting LFC state from postgres"); + info!(%url, "requesting LFC state from Postgres"); let mut compressed = Vec::new(); ComputeNode::get_maintenance_client(&self.tokio_conn_conf) @@ -205,13 +226,17 @@ impl ComputeNode { .read_to_end(&mut compressed) .await .context("compressing LFC state")?; + let compressed_len = compressed.len(); info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage"); let request = Client::new().put(url).bearer_auth(token).body(compressed); match request.send().await { Ok(res) if res.status() == StatusCode::OK => Ok(()), - Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()), + Ok(res) => bail!( + "Request to endpoint storage failed with status: {}", + res.status() + ), Err(err) => Err(err).context("writing to endpoint storage"), } } diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index 9da05f6598..8821611f0c 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -15,6 +15,8 @@ use crate::pg_helpers::{ }; use crate::tls::{self, SERVER_CRT, SERVER_KEY}; +use utils::shard::{ShardIndex, ShardNumber}; + /// Check that `line` is inside a text file and put it there if it is not. /// Create file if it doesn't exist. pub fn line_in_file(path: &Path, line: &str) -> Result { @@ -58,24 +60,53 @@ pub fn write_postgres_conf( // Add options for connecting to storage writeln!(file, "# Neon storage settings")?; - + writeln!(file)?; if let Some(conninfo) = &spec.pageserver_connection_info { + // Stripe size GUC should be defined prior to connection string + if let Some(stripe_size) = conninfo.stripe_size { + writeln!( + file, + "# from compute spec's pageserver_conninfo.stripe_size field" + )?; + writeln!(file, "neon.stripe_size={stripe_size}")?; + } + let mut libpq_urls: Option> = Some(Vec::new()); let mut grpc_urls: Option> = Some(Vec::new()); + let num_shards = if conninfo.shard_count.0 == 0 { + 1 // unsharded, treat it as a single shard + } else { + conninfo.shard_count.0 + }; - for shardno in 0..conninfo.shards.len() { - let info = conninfo.shards.get(&(shardno as u32)).ok_or_else(|| { - anyhow::anyhow!("shard {shardno} missing from pageserver_connection_info shard map") + for shard_number in 0..num_shards { + let shard_index = ShardIndex { + shard_number: ShardNumber(shard_number), + shard_count: conninfo.shard_count, + }; + let info = conninfo.shards.get(&shard_index).ok_or_else(|| { + anyhow::anyhow!( + "shard {shard_index} missing from pageserver_connection_info shard map" + ) })?; - if let Some(url) = &info.libpq_url { + let first_pageserver = info + .pageservers + .first() + .expect("must have at least one pageserver"); + + // Add the libpq URL to the array, or if the URL is missing, reset the array + // forgetting any previous entries. All servers must have a libpq URL, or none + // at all. + if let Some(url) = &first_pageserver.libpq_url { if let Some(ref mut urls) = libpq_urls { urls.push(url.clone()); } } else { libpq_urls = None } - if let Some(url) = &info.grpc_url { + // Similarly for gRPC URLs + if let Some(url) = &first_pageserver.grpc_url { if let Some(ref mut urls) = grpc_urls { urls.push(url.clone()); } @@ -84,6 +115,10 @@ pub fn write_postgres_conf( } } if let Some(libpq_urls) = libpq_urls { + writeln!( + file, + "# derived from compute spec's pageserver_conninfo field" + )?; writeln!( file, "neon.pageserver_connstring={}", @@ -93,6 +128,10 @@ pub fn write_postgres_conf( writeln!(file, "# no neon.pageserver_connstring")?; } if let Some(grpc_urls) = grpc_urls { + writeln!( + file, + "# derived from compute spec's pageserver_conninfo field" + )?; writeln!( file, "neon.pageserver_grpc_urls={}", @@ -101,11 +140,19 @@ pub fn write_postgres_conf( } else { writeln!(file, "# no neon.pageserver_grpc_urls")?; } + } else { + // Stripe size GUC should be defined prior to connection string + if let Some(stripe_size) = spec.shard_stripe_size { + writeln!(file, "# from compute spec's shard_stripe_size field")?; + writeln!(file, "neon.stripe_size={stripe_size}")?; + } + + if let Some(s) = &spec.pageserver_connstring { + writeln!(file, "# from compute spec's pageserver_connstring field")?; + writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?; + } } - if let Some(stripe_size) = spec.shard_stripe_size { - writeln!(file, "neon.stripe_size={stripe_size}")?; - } if !spec.safekeeper_connstrings.is_empty() { let mut neon_safekeepers_value = String::new(); tracing::info!( diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index 93a357e160..3cf5ea7c51 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -613,11 +613,11 @@ components: - skipped properties: status: - description: Lfc prewarm status - enum: [not_prewarmed, prewarming, completed, failed] + description: LFC prewarm status + enum: [not_prewarmed, prewarming, completed, failed, skipped] type: string error: - description: Lfc prewarm error, if any + description: LFC prewarm error, if any type: string total: description: Total pages processed @@ -635,11 +635,11 @@ components: - status properties: status: - description: Lfc offload status + description: LFC offload status enum: [not_offloaded, offloading, completed, failed] type: string error: - description: Lfc offload error, if any + description: LFC offload error, if any type: string PromoteState: diff --git a/compute_tools/src/lsn_lease.rs b/compute_tools/src/lsn_lease.rs index 241b4cf467..6abfea82e0 100644 --- a/compute_tools/src/lsn_lease.rs +++ b/compute_tools/src/lsn_lease.rs @@ -4,13 +4,13 @@ use std::thread; use std::time::{Duration, SystemTime}; use anyhow::{Result, bail}; -use compute_api::spec::{ComputeMode, PageserverConnectionInfo}; +use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverProtocol}; use pageserver_page_api as page_api; use postgres::{NoTls, SimpleQueryMessage}; use tracing::{info, warn}; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; -use utils::shard::{ShardCount, ShardNumber, TenantShardId}; +use utils::shard::TenantShardId; use crate::compute::ComputeNode; @@ -116,37 +116,38 @@ fn try_acquire_lsn_lease( timeline_id: TimelineId, lsn: Lsn, ) -> Result> { - let shard_count = conninfo.shards.len(); let mut leases = Vec::new(); - for (shard_number, shard) in conninfo.shards.into_iter() { - let tenant_shard_id = match shard_count { - 0 | 1 => TenantShardId::unsharded(tenant_id), - shard_count => TenantShardId { - tenant_id, - shard_number: ShardNumber(shard_number as u8), - shard_count: ShardCount::new(shard_count as u8), - }, + for (shard_index, shard) in conninfo.shards.into_iter() { + let tenant_shard_id = TenantShardId { + tenant_id, + shard_number: shard_index.shard_number, + shard_count: shard_index.shard_count, }; - let lease = if conninfo.prefer_grpc { - acquire_lsn_lease_grpc( - &shard.grpc_url.unwrap(), - auth, - tenant_shard_id, - timeline_id, - lsn, - )? - } else { - acquire_lsn_lease_libpq( - &shard.libpq_url.unwrap(), - auth, - tenant_shard_id, - timeline_id, - lsn, - )? - }; - leases.push(lease); + // XXX: If there are more than pageserver for the one shard, do we need to get a + // leas on all of them? Currently, that's what we assume, but this is hypothetical + // as of this writing, as we never pass the info for more than one pageserver per + // shard. + for pageserver in shard.pageservers { + let lease = match conninfo.prefer_protocol { + PageserverProtocol::Grpc => acquire_lsn_lease_grpc( + &pageserver.grpc_url.unwrap(), + auth, + tenant_shard_id, + timeline_id, + lsn, + )?, + PageserverProtocol::Libpq => acquire_lsn_lease_libpq( + &pageserver.libpq_url.unwrap(), + auth, + tenant_shard_id, + timeline_id, + lsn, + )?, + }; + leases.push(lease); + } } Ok(leases.into_iter().min().flatten()) diff --git a/control_plane/README.md b/control_plane/README.md index aa6f935e27..60c6120d82 100644 --- a/control_plane/README.md +++ b/control_plane/README.md @@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems. ## Example: Start with Postgres 16 -To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands. +To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 2 of the start-up commands. ```shell -cargo neon init --pg-version 16 +cargo neon init cargo neon start cargo neon tenant create --set-default --pg-version 16 cargo neon endpoint create main --pg-version 16 diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index ff4f15b0ab..6da3223024 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -16,9 +16,14 @@ use std::time::Duration; use anyhow::{Context, Result, anyhow, bail}; use clap::Parser; use compute_api::requests::ComputeClaimsScope; -use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverShardConnectionInfo}; +use compute_api::spec::{ + ComputeMode, PageserverConnectionInfo, PageserverProtocol, PageserverShardInfo, +}; use control_plane::broker::StorageBroker; use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode}; +use control_plane::endpoint::{ + pageserver_conf_to_shard_conn_info, tenant_locate_response_to_conn_info, +}; use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage}; use control_plane::local_env; use control_plane::local_env::{ @@ -44,7 +49,6 @@ use pageserver_api::models::{ }; use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId}; use postgres_backend::AuthType; -use postgres_connection::parse_host_port; use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId}; use safekeeper_api::{ DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT, @@ -52,11 +56,11 @@ use safekeeper_api::{ }; use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR; use tokio::task::JoinSet; -use url::Host; use utils::auth::{Claims, Scope}; use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId}; use utils::lsn::Lsn; use utils::project_git_version; +use utils::shard::ShardIndex; // Default id of a safekeeper node, if not specified on the command line. const DEFAULT_SAFEKEEPER_ID: NodeId = NodeId(1); @@ -1521,74 +1525,56 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res )?; } - let (shards, stripe_size) = if let Some(ps_id) = pageserver_id { - let conf = env.get_pageserver_conf(ps_id).unwrap(); - let libpq_url = Some({ - let (host, port) = parse_host_port(&conf.listen_pg_addr)?; - let port = port.unwrap_or(5432); - format!("postgres://no_user@{host}:{port}") - }); - let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr { - let (host, port) = parse_host_port(grpc_addr)?; - let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); - Some(format!("grpc://no_user@{host}:{port}")) - } else { - None - }; - let pageserver = PageserverShardConnectionInfo { - libpq_url, - grpc_url, - }; + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc + } else { + PageserverProtocol::Libpq + }; + let mut pageserver_conninfo = if let Some(ps_id) = pageserver_id { + let conf = env.get_pageserver_conf(ps_id).unwrap(); + let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?; + + let shard_info = PageserverShardInfo { + pageservers: vec![ps_conninfo], + }; // If caller is telling us what pageserver to use, this is not a tenant which is // fully managed by storage controller, therefore not sharded. - (vec![(0, pageserver)], DEFAULT_STRIPE_SIZE) + let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)] + .into_iter() + .collect(); + PageserverConnectionInfo { + shard_count: ShardCount(0), + stripe_size: None, + shards, + prefer_protocol, + } } else { // Look up the currently attached location of the tenant, and its striping metadata, // to pass these on to postgres. let storage_controller = StorageController::from_env(env); let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?; - let shards = futures::future::try_join_all(locate_result.shards.into_iter().map( - |shard| async move { - if let ComputeMode::Static(lsn) = endpoint.mode { - // Initialize LSN leases for static computes. + assert!(!locate_result.shards.is_empty()); + + // Initialize LSN leases for static computes. + if let ComputeMode::Static(lsn) = endpoint.mode { + futures::future::try_join_all(locate_result.shards.iter().map( + |shard| async move { let conf = env.get_pageserver_conf(shard.node_id).unwrap(); let pageserver = PageServerNode::from_env(env, conf); pageserver .http_client .timeline_init_lsn_lease(shard.shard_id, endpoint.timeline_id, lsn) - .await?; - } + .await + }, + )) + .await?; + } - let libpq_host = Host::parse(&shard.listen_pg_addr)?; - let libpq_port = shard.listen_pg_port; - let libpq_url = - Some(format!("postgres://no_user@{libpq_host}:{libpq_port}")); - - let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr { - let grpc_port = shard.listen_grpc_port.expect("no gRPC port"); - Some(format!("grpc://no_user@{grpc_host}:{grpc_port}")) - } else { - None - }; - let pageserver = PageserverShardConnectionInfo { - libpq_url, - grpc_url, - }; - anyhow::Ok((shard.shard_id.shard_number.0 as u32, pageserver)) - }, - )) - .await?; - let stripe_size = locate_result.shard_params.stripe_size; - - (shards, stripe_size) - }; - assert!(!shards.is_empty()); - let pageserver_conninfo = PageserverConnectionInfo { - shards: shards.into_iter().collect(), - prefer_grpc: endpoint.grpc, + tenant_locate_response_to_conn_info(&locate_result)? }; + pageserver_conninfo.prefer_protocol = prefer_protocol; let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?; let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) { @@ -1620,7 +1606,6 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res safekeepers, pageserver_conninfo, remote_ext_base_url: remote_ext_base_url.clone(), - shard_stripe_size: stripe_size.0 as usize, create_test_user: args.create_test_user, start_timeout: args.start_timeout, autoprewarm: args.autoprewarm, @@ -1637,66 +1622,45 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .endpoints .get(endpoint_id.as_str()) .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; - let shards = if let Some(ps_id) = args.endpoint_pageserver_id { + + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc + } else { + PageserverProtocol::Libpq + }; + let mut pageserver_conninfo = if let Some(ps_id) = args.endpoint_pageserver_id { let conf = env.get_pageserver_conf(ps_id)?; - let libpq_url = Some({ - let (host, port) = parse_host_port(&conf.listen_pg_addr)?; - let port = port.unwrap_or(5432); - format!("postgres://no_user@{host}:{port}") - }); - let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr { - let (host, port) = parse_host_port(grpc_addr)?; - let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); - Some(format!("grpc://no_user@{host}:{port}")) - } else { - None - }; - let pageserver = PageserverShardConnectionInfo { - libpq_url, - grpc_url, + let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?; + let shard_info = PageserverShardInfo { + pageservers: vec![ps_conninfo], }; + // If caller is telling us what pageserver to use, this is not a tenant which is // fully managed by storage controller, therefore not sharded. - vec![(0, pageserver)] - } else { - let storage_controller = StorageController::from_env(env); - storage_controller - .tenant_locate(endpoint.tenant_id) - .await? - .shards + let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)] .into_iter() - .map(|shard| { - // Use gRPC if requested. - let libpq_host = Host::parse(&shard.listen_pg_addr).expect("bad hostname"); - let libpq_port = shard.listen_pg_port; - let libpq_url = - Some(format!("postgres://no_user@{libpq_host}:{libpq_port}")); + .collect(); + PageserverConnectionInfo { + shard_count: ShardCount::unsharded(), + stripe_size: None, + shards, + prefer_protocol, + } + } else { + // Look up the currently attached location of the tenant, and its striping metadata, + // to pass these on to postgres. + let storage_controller = StorageController::from_env(env); + let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?; - let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr { - let grpc_port = shard.listen_grpc_port.expect("no gRPC port"); - Some(format!("grpc://no_user@{grpc_host}:{grpc_port}")) - } else { - None - }; - ( - shard.shard_id.shard_number.0 as u32, - PageserverShardConnectionInfo { - libpq_url, - grpc_url, - }, - ) - }) - .collect::>() - }; - let pageserver_conninfo = PageserverConnectionInfo { - shards: shards.into_iter().collect(), - prefer_grpc: endpoint.grpc, + tenant_locate_response_to_conn_info(&locate_result)? }; + pageserver_conninfo.prefer_protocol = prefer_protocol; + // If --safekeepers argument is given, use only the listed // safekeeper nodes; otherwise all from the env. let safekeepers = parse_safekeepers(&args.safekeepers)?; endpoint - .reconfigure(Some(pageserver_conninfo), None, safekeepers, None) + .reconfigure(Some(&pageserver_conninfo), safekeepers, None) .await?; } EndpointCmd::Stop(args) => { diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 2339b8b923..58a419b965 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -37,7 +37,7 @@ //! //! ``` //! -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::fmt::Display; use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; use std::path::PathBuf; @@ -57,8 +57,8 @@ use compute_api::responses::{ TlsConfig, }; use compute_api::spec::{ - Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent, - RemoteExtSpec, Role, + Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol, + PageserverShardInfo, PgIdent, RemoteExtSpec, Role, }; // re-export these, because they're used in the reconfigure() function @@ -69,7 +69,6 @@ use jsonwebtoken::jwk::{ OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, }; use nix::sys::signal::{Signal, kill}; -use pageserver_api::shard::ShardStripeSize; use pem::Pem; use reqwest::header::CONTENT_TYPE; use safekeeper_api::PgMajorVersion; @@ -80,6 +79,10 @@ use spki::der::Decode; use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef}; use tracing::debug; use utils::id::{NodeId, TenantId, TimelineId}; +use utils::shard::{ShardIndex, ShardNumber}; + +use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT; +use postgres_connection::parse_host_port; use crate::local_env::LocalEnv; use crate::postgresql_conf::PostgresConf; @@ -392,7 +395,6 @@ pub struct EndpointStartArgs { pub safekeepers: Vec, pub pageserver_conninfo: PageserverConnectionInfo, pub remote_ext_base_url: Option, - pub shard_stripe_size: usize, pub create_test_user: bool, pub start_timeout: Duration, pub autoprewarm: bool, @@ -724,6 +726,46 @@ impl Endpoint { remote_extensions = None; }; + // For the sake of backwards-compatibility, also fill in 'pageserver_connstring' + // + // XXX: I believe this is not really needed, except to make + // test_forward_compatibility happy. + // + // Use a closure so that we can conviniently return None in the middle of the + // loop. + let pageserver_connstring = (|| { + let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() { + 1 + } else { + args.pageserver_conninfo.shard_count.0 + }; + let mut connstrings = Vec::new(); + for shard_no in 0..num_shards { + let shard_index = ShardIndex { + shard_count: args.pageserver_conninfo.shard_count, + shard_number: ShardNumber(shard_no), + }; + let shard = args + .pageserver_conninfo + .shards + .get(&shard_index) + .expect(&format!( + "shard {} not found in pageserver_connection_info", + shard_index + )); + let pageserver = shard + .pageservers + .first() + .expect("must have at least one pageserver"); + if let Some(libpq_url) = &pageserver.libpq_url { + connstrings.push(libpq_url.clone()); + } else { + return None; + } + } + Some(connstrings.join(",")) + })(); + // Create config file let config = { let mut spec = ComputeSpec { @@ -768,13 +810,14 @@ impl Endpoint { branch_id: None, endpoint_id: Some(self.endpoint_id.clone()), mode: self.mode, - pageserver_connection_info: Some(args.pageserver_conninfo), + pageserver_connection_info: Some(args.pageserver_conninfo.clone()), + pageserver_connstring, safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()), safekeeper_connstrings, storage_auth_token: args.auth_token.clone(), remote_extensions, pgbouncer_settings: None, - shard_stripe_size: Some(args.shard_stripe_size), + shard_stripe_size: args.pageserver_conninfo.stripe_size, // redundant with pageserver_connection_info.stripe_size local_proxy_config: None, reconfigure_concurrency: self.reconfigure_concurrency, drop_subscriptions_before_start: self.drop_subscriptions_before_start, @@ -986,8 +1029,7 @@ impl Endpoint { pub async fn reconfigure( &self, - pageserver_conninfo: Option, - stripe_size: Option, + pageserver_conninfo: Option<&PageserverConnectionInfo>, safekeepers: Option>, safekeeper_generation: Option, ) -> Result<()> { @@ -1009,10 +1051,8 @@ impl Endpoint { !pageserver_conninfo.shards.is_empty(), "no pageservers provided" ); - spec.pageserver_connection_info = Some(pageserver_conninfo); - } - if stripe_size.is_some() { - spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize); + spec.pageserver_connection_info = Some(pageserver_conninfo.clone()); + spec.shard_stripe_size = pageserver_conninfo.stripe_size; } // If safekeepers are not specified, don't change them. @@ -1061,11 +1101,9 @@ impl Endpoint { pub async fn reconfigure_pageservers( &self, - pageservers: PageserverConnectionInfo, - stripe_size: Option, + pageservers: &PageserverConnectionInfo, ) -> Result<()> { - self.reconfigure(Some(pageservers), stripe_size, None, None) - .await + self.reconfigure(Some(pageservers), None, None).await } pub async fn reconfigure_safekeepers( @@ -1073,7 +1111,7 @@ impl Endpoint { safekeepers: Vec, generation: SafekeeperGeneration, ) -> Result<()> { - self.reconfigure(None, None, Some(safekeepers), Some(generation)) + self.reconfigure(None, Some(safekeepers), Some(generation)) .await } @@ -1129,3 +1167,68 @@ impl Endpoint { ) } } + +pub fn pageserver_conf_to_shard_conn_info( + conf: &crate::local_env::PageServerConf, +) -> Result { + let libpq_url = { + let (host, port) = parse_host_port(&conf.listen_pg_addr)?; + let port = port.unwrap_or(5432); + Some(format!("postgres://no_user@{host}:{port}")) + }; + let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr { + let (host, port) = parse_host_port(grpc_addr)?; + let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); + Some(format!("grpc://no_user@{host}:{port}")) + } else { + None + }; + Ok(PageserverShardConnectionInfo { + id: Some(conf.id.to_string()), + libpq_url, + grpc_url, + }) +} + +pub fn tenant_locate_response_to_conn_info( + response: &pageserver_api::controller_api::TenantLocateResponse, +) -> Result { + let mut shards = HashMap::new(); + for shard in response.shards.iter() { + tracing::info!("parsing {}", shard.listen_pg_addr); + let libpq_url = { + let host = &shard.listen_pg_addr; + let port = shard.listen_pg_port; + Some(format!("postgres://no_user@{host}:{port}")) + }; + let grpc_url = if let Some(grpc_addr) = &shard.listen_grpc_addr { + let host = grpc_addr; + let port = shard.listen_grpc_port.expect("no gRPC port"); + Some(format!("grpc://no_user@{host}:{port}")) + } else { + None + }; + + let shard_info = PageserverShardInfo { + pageservers: vec![PageserverShardConnectionInfo { + id: Some(shard.node_id.to_string()), + libpq_url, + grpc_url, + }], + }; + + shards.insert(shard.shard_id.to_index(), shard_info); + } + + let stripe_size = if response.shard_params.count.is_unsharded() { + None + } else { + Some(response.shard_params.stripe_size.0) + }; + Ok(PageserverConnectionInfo { + shard_count: response.shard_params.count, + stripe_size, + shards, + prefer_protocol: PageserverProtocol::default(), + }) +} diff --git a/control_plane/storcon_cli/src/main.rs b/control_plane/storcon_cli/src/main.rs index fcc5549beb..a4d1030488 100644 --- a/control_plane/storcon_cli/src/main.rs +++ b/control_plane/storcon_cli/src/main.rs @@ -76,6 +76,12 @@ enum Command { NodeStartDelete { #[arg(long)] node_id: NodeId, + /// When `force` is true, skip waiting for shards to prewarm during migration. + /// This can significantly speed up node deletion since prewarming all shards + /// can take considerable time, but may result in slower initial access to + /// migrated shards until they warm up naturally. + #[arg(long)] + force: bool, }, /// Cancel deletion of the specified pageserver and wait for `timeout` /// for the operation to be canceled. May be retried. @@ -952,13 +958,14 @@ async fn main() -> anyhow::Result<()> { .dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None) .await?; } - Command::NodeStartDelete { node_id } => { + Command::NodeStartDelete { node_id, force } => { + let query = if force { + format!("control/v1/node/{node_id}/delete?force=true") + } else { + format!("control/v1/node/{node_id}/delete") + }; storcon_client - .dispatch::<(), ()>( - Method::PUT, - format!("control/v1/node/{node_id}/delete"), - None, - ) + .dispatch::<(), ()>(Method::PUT, query, None) .await?; println!("Delete started for {node_id}"); } diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index 2fe233214a..5b8fc49750 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -46,16 +46,33 @@ pub struct ExtensionInstallResponse { pub version: ExtVersion, } +/// Status of the LFC prewarm process. The same state machine is reused for +/// both autoprewarm (prewarm after compute/Postgres start using the previously +/// stored LFC state) and explicit prewarming via API. #[derive(Serialize, Default, Debug, Clone, PartialEq)] #[serde(tag = "status", rename_all = "snake_case")] pub enum LfcPrewarmState { + /// Default value when compute boots up. #[default] NotPrewarmed, + /// Prewarming thread is active and loading pages into LFC. Prewarming, + /// We found requested LFC state in the endpoint storage and + /// completed prewarming successfully. Completed, - Failed { - error: String, - }, + /// Unexpected error happened during prewarming. Note, `Not Found 404` + /// response from the endpoint storage is explicitly excluded here + /// because it can normally happen on the first compute start, + /// since LFC state is not available yet. + Failed { error: String }, + /// We tried to fetch the corresponding LFC state from the endpoint storage, + /// but received `Not Found 404`. This should normally happen only during the + /// first endpoint start after creation with `autoprewarm: true`. + /// + /// During the orchestrated prewarm via API, when a caller explicitly + /// provides the LFC state key to prewarm from, it's the caller responsibility + /// to handle this status as an error state in this case. + Skipped, } impl Display for LfcPrewarmState { @@ -64,6 +81,7 @@ impl Display for LfcPrewarmState { LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"), LfcPrewarmState::Prewarming => f.write_str("Prewarming"), LfcPrewarmState::Completed => f.write_str("Completed"), + LfcPrewarmState::Skipped => f.write_str("Skipped"), LfcPrewarmState::Failed { error } => write!(f, "Error({error})"), } } diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index d011f53ee9..f7ffcd6444 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -14,6 +14,7 @@ use serde::{Deserialize, Serialize}; use url::Url; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; +use utils::shard::{ShardCount, ShardIndex}; use crate::responses::TlsConfig; @@ -106,11 +107,18 @@ pub struct ComputeSpec { pub tenant_id: Option, pub timeline_id: Option, - // Pageserver information can be passed in two different ways: - // 1. Here - // 2. in cluster.settings. This is legacy, we are switching to method 1. + /// Pageserver information can be passed in three different ways: + /// 1. Here in `pageserver_connection_info` + /// 2. In the `pageserver_connstring` field. + /// 3. in `cluster.settings`. + /// + /// The goal is to use method 1. everywhere. But for backwards-compatibility with old + /// versions of the control plane, `compute_ctl` will check 2. and 3. if the + /// `pageserver_connection_info` field is missing. pub pageserver_connection_info: Option, + pub pageserver_connstring: Option, + // More neon ids that we expose to the compute_ctl // and to postgres as neon extension GUCs. pub project_id: Option, @@ -145,7 +153,7 @@ pub struct ComputeSpec { // Stripe size for pageserver sharding, in pages #[serde(default)] - pub shard_stripe_size: Option, + pub shard_stripe_size: Option, /// Local Proxy configuration used for JWT authentication #[serde(default)] @@ -218,16 +226,28 @@ pub enum ComputeFeature { UnknownFeature, } -/// Feature flag to signal `compute_ctl` to enable certain experimental functionality. -#[derive(Clone, Debug, Default, Deserialize, Serialize, Eq, PartialEq)] +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub struct PageserverConnectionInfo { - pub shards: HashMap, + /// NB: 0 for unsharded tenants, 1 for sharded tenants with 1 shard, following storage + pub shard_count: ShardCount, - pub prefer_grpc: bool, + /// INVARIANT: null if shard_count is 0, otherwise non-null and immutable + pub stripe_size: Option, + + pub shards: HashMap, + + #[serde(default)] + pub prefer_protocol: PageserverProtocol, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, Eq, PartialEq)] +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +pub struct PageserverShardInfo { + pub pageservers: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub struct PageserverShardConnectionInfo { + pub id: Option, pub libpq_url: Option, pub grpc_url: Option, } @@ -465,13 +485,15 @@ pub struct JwksSettings { pub jwt_audience: Option, } -/// Protocol used to connect to a Pageserver. Parsed from the connstring scheme. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +/// Protocol used to connect to a Pageserver. +#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)] pub enum PageserverProtocol { /// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme. #[default] + #[serde(rename = "libpq")] Libpq, /// A newer, gRPC-based protocol. Uses grpc:// scheme. + #[serde(rename = "grpc")] Grpc, } diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 5d028ee041..41873cdcd6 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -4,12 +4,14 @@ //! a default registry. #![deny(clippy::undocumented_unsafe_blocks)] +use std::sync::RwLock; + use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels}; use measured::metric::counter::CounterState; use measured::metric::gauge::GaugeState; use measured::metric::group::Encoding; use measured::metric::name::{MetricName, MetricNameEncoder}; -use measured::metric::{MetricEncoding, MetricFamilyEncoding}; +use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType}; use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup}; use once_cell::sync::Lazy; use prometheus::Registry; @@ -116,12 +118,52 @@ pub fn pow2_buckets(start: usize, end: usize) -> Vec { .collect() } +pub struct InfoMetric { + label: RwLock, + metric: M, +} + +impl InfoMetric { + pub fn new(label: L) -> Self { + Self::with_metric(label, GaugeState::new(1)) + } +} + +impl> InfoMetric { + pub fn with_metric(label: L, metric: M) -> Self { + Self { + label: RwLock::new(label), + metric, + } + } + + pub fn set_label(&self, label: L) { + *self.label.write().unwrap() = label; + } +} + +impl MetricFamilyEncoding for InfoMetric +where + L: LabelGroup, + M: MetricEncoding, + E: Encoding, +{ + fn collect_family_into( + &self, + name: impl measured::metric::name::MetricNameEncoder, + enc: &mut E, + ) -> Result<(), E::Err> { + M::write_type(&name, enc)?; + self.metric + .collect_into(&(), &*self.label.read().unwrap(), name, enc) + } +} + pub struct BuildInfo { pub revision: &'static str, pub build_tag: &'static str, } -// todo: allow label group without the set impl LabelGroup for BuildInfo { fn visit_values(&self, v: &mut impl LabelGroupVisitor) { const REVISION: &LabelName = LabelName::from_str("revision"); @@ -131,24 +173,6 @@ impl LabelGroup for BuildInfo { } } -impl MetricFamilyEncoding for BuildInfo -where - GaugeState: MetricEncoding, -{ - fn collect_family_into( - &self, - name: impl measured::metric::name::MetricNameEncoder, - enc: &mut T, - ) -> Result<(), T::Err> { - enc.write_help(&name, "Build/version information")?; - GaugeState::write_type(&name, enc)?; - GaugeState { - count: std::sync::atomic::AtomicI64::new(1), - } - .collect_into(&(), self, name, enc) - } -} - #[derive(MetricGroup)] #[metric(new(build_info: BuildInfo))] pub struct NeonMetrics { @@ -165,8 +189,8 @@ pub struct NeonMetrics { #[derive(MetricGroup)] #[metric(new(build_info: BuildInfo))] pub struct LibMetrics { - #[metric(init = build_info)] - build_info: BuildInfo, + #[metric(init = InfoMetric::new(build_info))] + build_info: InfoMetric, #[metric(flatten)] rusage: Rusage, diff --git a/libs/neon-shmem/Cargo.toml b/libs/neon-shmem/Cargo.toml index 8ce5b52deb..da014de17b 100644 --- a/libs/neon-shmem/Cargo.toml +++ b/libs/neon-shmem/Cargo.toml @@ -8,20 +8,19 @@ license.workspace = true thiserror.workspace = true nix.workspace = true workspace_hack = { version = "0.1", path = "../../workspace_hack" } -rustc-hash = { version = "2.1.1" } -rand = "0.9.1" libc.workspace = true -lock_api = "0.4.13" +lock_api.workspace = true +rustc-hash.workspace = true [dev-dependencies] criterion = { workspace = true, features = ["html_reports"] } +rand = "0.9" rand_distr = "0.5.1" xxhash-rust = { version = "0.8.15", features = ["xxh3"] } ahash.workspace = true twox-hash = { version = "2.1.1" } seahash = "4.1.0" hashbrown = { git = "https://github.com/quantumish/hashbrown.git", rev = "6610e6d" } -foldhash = "0.1.5" [target.'cfg(target_os = "macos")'.dependencies] diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 6fc7baefcc..627dd5d15c 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -13,6 +13,8 @@ //! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen //! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the //! dictionary by rehashing all keys. +//! +//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock. use std::fmt::Debug; use std::hash::{BuildHasher, Hash}; @@ -30,6 +32,19 @@ mod tests; use core::{Bucket, CoreHashMap, INVALID_POS}; use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry}; +use thiserror::Error; + +/// Error type for a hashmap shrink operation. +#[derive(Error, Debug)] +pub enum HashMapShrinkError { + /// There was an error encountered while resizing the memory area. + #[error("shmem resize failed: {0}")] + ResizeError(shmem::Error), + /// Occupied entries in to-be-shrunk space were encountered beginning at the given index. + #[error("occupied entry in deallocated space found at {0}")] + RemainingEntries(usize), +} + /// This represents a hash table that (possibly) lives in shared memory. /// If a new process is launched with fork(), the child process inherits /// this struct. @@ -147,8 +162,8 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { }; let hashmap = CoreHashMap::new(buckets, dictionary); - let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap); unsafe { + let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap); std::ptr::write(shared_ptr, lock); } @@ -171,6 +186,9 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { } /// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`]. + /// + /// This is a holdover from a previous implementation and is being kept around for + /// backwards compatibility reasons. pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> { self.attach_writer() } @@ -184,8 +202,8 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { /// /// [`libc::pthread_rwlock_t`] /// [`HashMapShared`] -/// [buckets] -/// [dictionary] +/// buckets +/// dictionary /// /// In between the above parts, there can be padding bytes to align the parts correctly. type HashMapShared<'a, K, V> = RwLock>; @@ -310,6 +328,9 @@ where } /// Get a reference to the entry containing a key. + /// + /// NB: This takes a write lock as there's no way to distinguish whether the intention + /// is to use the entry for reading or for writing in advance. pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> { let hash = self.get_hash_value(&key); self.entry_with_hash(key, hash) @@ -317,7 +338,7 @@ where /// Remove a key given its hash. Returns the associated value if it existed. pub fn remove(&self, key: &K) -> Option { - let hash = self.get_hash_value(&key); + let hash = self.get_hash_value(key); match self.entry_with_hash(key.clone(), hash) { Entry::Occupied(e) => Some(e.remove()), Entry::Vacant(_) => None, @@ -355,7 +376,7 @@ where Some((key, _)) => Some(OccupiedEntry { _key: key.clone(), bucket_pos: pos as u32, - prev_pos: entry::PrevPos::Unknown(self.get_hash_value(&key)), + prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)), map, }), _ => None, @@ -550,12 +571,7 @@ where /// The following cases result in a panic: /// - Calling this function on a map initialized with [`HashMapInit::with_fixed`]. /// - Calling this function on a map when no shrink operation is in progress. - /// - Calling this function on a map with `shrink_mode` set to [`HashMapShrinkMode::Remap`] and - /// there are more buckets in use than the value returned by [`HashMapAccess::shrink_goal`]. - /// - /// # Errors - /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. - pub fn finish_shrink(&self) -> Result<(), shmem::Error> { + pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> { let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); assert!( map.alloc_limit != INVALID_POS, @@ -574,10 +590,8 @@ where ); for i in (num_buckets as usize)..map.buckets.len() { - if let Some((k, v)) = map.buckets[i].inner.take() { - // alloc_bucket increases count, so need to decrease since we're just moving - map.buckets_in_use -= 1; - map.alloc_bucket(k, v).unwrap(); + if map.buckets[i].inner.is_some() { + return Err(HashMapShrinkError::RemainingEntries(i)); } } @@ -587,7 +601,9 @@ where .expect("shrink called on a fixed-size hash table"); let size_bytes = HashMapInit::::estimate_size(num_buckets); - shmem_handle.set_size(size_bytes)?; + if let Err(e) = shmem_handle.set_size(size_bytes) { + return Err(HashMapShrinkError::ResizeError(e)); + } let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) }; let buckets_ptr = map.buckets.as_mut_ptr(); self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets); diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs index 67be7672d1..0d40a1877b 100644 --- a/libs/neon-shmem/src/hash/core.rs +++ b/libs/neon-shmem/src/hash/core.rs @@ -43,9 +43,6 @@ pub(crate) struct CoreHashMap<'a, K, V> { pub(crate) alloc_limit: u32, /// The number of currently occupied buckets. pub(crate) buckets_in_use: u32, - // pub(crate) lock: libc::pthread_mutex_t, - // Unclear what the purpose of this is. - pub(crate) _user_list_head: u32, } impl<'a, K, V> Debug for CoreHashMap<'a, K, V> @@ -66,7 +63,7 @@ where /// Error for when there are no empty buckets left but one is needed. #[derive(Debug, PartialEq)] -pub struct FullError(); +pub struct FullError; impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { const FILL_FACTOR: f32 = 0.60; @@ -118,7 +115,6 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { buckets, free_head: 0, buckets_in_use: 0, - _user_list_head: INVALID_POS, alloc_limit: INVALID_POS, } } @@ -179,7 +175,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { pos = bucket.next; } if pos == INVALID_POS { - return Err(FullError()); + return Err(FullError); } // Repair the freelist. diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs index bf2f63fe9c..93ae700483 100644 --- a/libs/neon-shmem/src/hash/entry.rs +++ b/libs/neon-shmem/src/hash/entry.rs @@ -90,7 +90,6 @@ impl OccupiedEntry<'_, '_, K, V> { self.map.dictionary[dict_pos as usize] = bucket.next; } PrevPos::Chained(bucket_pos) => { - // println!("we think prev of {} is {bucket_pos}", self.bucket_pos); self.map.buckets[bucket_pos as usize].next = bucket.next; } _ => unreachable!(), @@ -125,9 +124,6 @@ impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> { /// Will return [`FullError`] if there are no unoccupied buckets in the map. pub fn insert(mut self, value: V) -> Result, FullError> { let pos = self.map.alloc_bucket(self.key, value)?; - if pos == INVALID_POS { - return Err(FullError()); - } self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize]; self.map.dictionary[self.dict_pos as usize] = pos; diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs index aee47a0b3e..92233e8140 100644 --- a/libs/neon-shmem/src/hash/tests.rs +++ b/libs/neon-shmem/src/hash/tests.rs @@ -164,16 +164,16 @@ fn do_deletes( fn do_shrink( writer: &mut HashMapAccess, shadow: &mut BTreeMap, + from: u32, to: u32, ) { assert!(writer.shrink_goal().is_none()); writer.begin_shrink(to); assert_eq!(writer.shrink_goal(), Some(to as usize)); - while writer.get_num_buckets_in_use() > to as usize { - let (k, _) = shadow.pop_first().unwrap(); - let entry = writer.entry(k); - if let Entry::Occupied(e) = entry { - e.remove(); + for i in to..from { + if let Some(entry) = writer.entry_at_bucket(i as usize) { + shadow.remove(&entry._key); + entry.remove(); } } let old_usage = writer.get_num_buckets_in_use(); @@ -298,7 +298,7 @@ fn test_shrink() { let mut rng = rand::rng(); do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); - do_shrink(&mut writer, &mut shadow, 1000); + do_shrink(&mut writer, &mut shadow, 1500, 1000); assert_eq!(writer.get_num_buckets(), 1000); do_deletes(500, &mut writer, &mut shadow); do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng); @@ -315,7 +315,7 @@ fn test_shrink_grow_seq() { do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng); eprintln!("Shrinking to 750"); - do_shrink(&mut writer, &mut shadow, 750); + do_shrink(&mut writer, &mut shadow, 1000, 750); do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng); eprintln!("Growing to 1500"); writer.grow(1500).unwrap(); @@ -324,7 +324,7 @@ fn test_shrink_grow_seq() { while shadow.len() > 100 { do_deletes(1, &mut writer, &mut shadow); } - do_shrink(&mut writer, &mut shadow, 200); + do_shrink(&mut writer, &mut shadow, 1500, 200); do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng); eprintln!("Growing to 10k"); writer.grow(10000).unwrap(); @@ -349,8 +349,7 @@ fn test_bucket_ops() { let pos = match writer.entry(1.into()) { Entry::Occupied(e) => { assert_eq!(e._key, 1.into()); - let pos = e.bucket_pos as usize; - pos + e.bucket_pos as usize } Entry::Vacant(_) => { panic!("Insert didn't affect entry"); diff --git a/libs/neon-shmem/src/lib.rs b/libs/neon-shmem/src/lib.rs index 61ca168073..226cc0c22d 100644 --- a/libs/neon-shmem/src/lib.rs +++ b/libs/neon-shmem/src/lib.rs @@ -1,5 +1,3 @@ -//! Shared memory utilities for neon communicator - pub mod hash; pub mod shmem; pub mod sync; diff --git a/libs/neon-shmem/src/sync.rs b/libs/neon-shmem/src/sync.rs index 5a296b4047..95719778ba 100644 --- a/libs/neon-shmem/src/sync.rs +++ b/libs/neon-shmem/src/sync.rs @@ -6,7 +6,7 @@ use std::ptr::NonNull; use nix::errno::Errno; pub type RwLock = lock_api::RwLock; -pub(crate) type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>; +pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>; pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>; pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>; pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>; @@ -14,19 +14,34 @@ pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRw /// Shared memory read-write lock. pub struct PthreadRwLock(Option>); +/// Simple macro that calls a function in the libc namespace and panics if return value is nonzero. +macro_rules! libc_checked { + ($fn_name:ident ( $($arg:expr),* )) => {{ + let res = libc::$fn_name($($arg),*); + if res != 0 { + panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(res)); + } + }}; +} + impl PthreadRwLock { - pub fn new(lock: *mut libc::pthread_rwlock_t) -> Self { + /// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock. + /// + /// # Safety + /// `lock` must be non-null. Every unsafe operation will panic in the event of an error. + pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self { unsafe { let mut attrs = MaybeUninit::uninit(); - // Ignoring return value here - only possible error is OOM. - libc::pthread_rwlockattr_init(attrs.as_mut_ptr()); - libc::pthread_rwlockattr_setpshared(attrs.as_mut_ptr(), libc::PTHREAD_PROCESS_SHARED); - // TODO(quantumish): worth making this function return Result? - libc::pthread_rwlock_init(lock, attrs.as_mut_ptr()); + libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr())); + libc_checked!(pthread_rwlockattr_setpshared( + attrs.as_mut_ptr(), + libc::PTHREAD_PROCESS_SHARED + )); + libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr())); // Safety: POSIX specifies that "any function affecting the attributes // object (including destruction) shall not affect any previously // initialized read-write locks". - libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr()); + libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr())); Self(Some(NonNull::new_unchecked(lock))) } } @@ -34,7 +49,7 @@ impl PthreadRwLock { fn inner(&self) -> NonNull { match self.0 { None => { - panic!("PthreadRwLock constructed badly - something likely used RawMutex::INIT") + panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT") } Some(x) => x, } @@ -45,31 +60,16 @@ unsafe impl lock_api::RawRwLock for PthreadRwLock { type GuardMarker = lock_api::GuardSend; const INIT: Self = Self(None); - fn lock_shared(&self) { - unsafe { - let res = libc::pthread_rwlock_rdlock(self.inner().as_ptr()); - if res != 0 { - panic!("rdlock failed with {}", Errno::from_raw(res)); - } - } - } - fn try_lock_shared(&self) -> bool { unsafe { let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr()); match res { 0 => true, libc::EAGAIN => false, - _ => panic!("try_rdlock failed with {}", Errno::from_raw(res)), - } - } - } - - fn lock_exclusive(&self) { - unsafe { - let res = libc::pthread_rwlock_wrlock(self.inner().as_ptr()); - if res != 0 { - panic!("wrlock failed with {}", Errno::from_raw(res)); + _ => panic!( + "pthread_rwlock_tryrdlock failed with {}", + Errno::from_raw(res) + ), } } } @@ -85,20 +85,27 @@ unsafe impl lock_api::RawRwLock for PthreadRwLock { } } - unsafe fn unlock_exclusive(&self) { + fn lock_shared(&self) { unsafe { - let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); - if res != 0 { - panic!("unlock failed with {}", Errno::from_raw(res)); - } + libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr())); } } + + fn lock_exclusive(&self) { + unsafe { + libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr())); + } + } + + unsafe fn unlock_exclusive(&self) { + unsafe { + libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr())); + } + } + unsafe fn unlock_shared(&self) { unsafe { - let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); - if res != 0 { - panic!("unlock failed with {}", Errno::from_raw(res)); - } + libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr())); } } } diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index 851d824291..20afa8bb46 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -749,7 +749,18 @@ impl PostgresBackend { trace!("got query {query_string:?}"); if let Err(e) = handler.process_query(self, query_string).await { match e { - QueryError::Shutdown => return Ok(ProcessMsgResult::Break), + err @ QueryError::Shutdown => { + // Notify postgres of the connection shutdown at the libpq + // protocol level. This avoids postgres having to tell apart + // from an idle connection and a stale one, which is bug prone. + let shutdown_error = short_error(&err); + self.write_message_noflush(&BeMessage::ErrorResponse( + &shutdown_error, + Some(err.pg_error_code()), + ))?; + + return Ok(ProcessMsgResult::Break); + } QueryError::SimulatedConnectionError => { return Err(QueryError::SimulatedConnectionError); } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 7b1dc56071..4b326949d7 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -47,6 +47,7 @@ tracing-subscriber = { workspace = true, features = ["json", "registry"] } tracing-utils.workspace = true rand.workspace = true scopeguard.workspace = true +uuid.workspace = true strum.workspace = true strum_macros.workspace = true walkdir.workspace = true diff --git a/libs/utils/src/auth.rs b/libs/utils/src/auth.rs index de3a964d23..b2aade15de 100644 --- a/libs/utils/src/auth.rs +++ b/libs/utils/src/auth.rs @@ -12,7 +12,8 @@ use jsonwebtoken::{ Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode, }; use pem::Pem; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned}; +use uuid::Uuid; use crate::id::TenantId; @@ -25,6 +26,11 @@ pub enum Scope { /// Provides access to all data for a specific tenant (specified in `struct Claims` below) // TODO: join these two? Tenant, + /// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope + /// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped + /// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint + /// scope token to ensure that untrusted compute nodes can't fetch spec for arbitrary endpoints. + TenantEndpoint, /// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs. /// Should only be used e.g. for status check/tenant creation/list. PageServerApi, @@ -51,17 +57,43 @@ pub enum Scope { ControllerPeer, } +fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let opt = Option::::deserialize(deserializer)?; + match opt.as_deref() { + Some("") => Ok(None), + Some(s) => Uuid::parse_str(s) + .map(Some) + .map_err(serde::de::Error::custom), + None => Ok(None), + } +} + /// JWT payload. See docs/authentication.md for the format #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct Claims { #[serde(default)] pub tenant_id: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + // Neon control plane includes this field as empty in the claims. + // Consider it None in those cases. + deserialize_with = "deserialize_empty_string_as_none_uuid" + )] + pub endpoint_id: Option, pub scope: Scope, } impl Claims { pub fn new(tenant_id: Option, scope: Scope) -> Self { - Self { tenant_id, scope } + Self { + tenant_id, + scope, + endpoint_id: None, + } } } @@ -212,6 +244,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH let expected_claims = Claims { tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()), scope: Scope::Tenant, + endpoint_id: None, }; // A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519: @@ -240,6 +273,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH let claims = Claims { tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()), scope: Scope::Tenant, + endpoint_id: None, }; let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap(); diff --git a/libs/utils/src/shard.rs b/libs/utils/src/shard.rs index 5a0edf8cea..3345549819 100644 --- a/libs/utils/src/shard.rs +++ b/libs/utils/src/shard.rs @@ -53,6 +53,10 @@ impl ShardCount { pub const MAX: Self = Self(u8::MAX); pub const MIN: Self = Self(0); + pub fn unsharded() -> Self { + ShardCount(0) + } + /// The internal value of a ShardCount may be zero, which means "1 shard, but use /// legacy format for TenantShardId that excludes the shard suffix", also known /// as [`TenantShardId::unsharded`]. diff --git a/pageserver/client/src/mgmt_api.rs b/pageserver/client/src/mgmt_api.rs index fe1ddc2e7d..3867e536f4 100644 --- a/pageserver/client/src/mgmt_api.rs +++ b/pageserver/client/src/mgmt_api.rs @@ -873,6 +873,22 @@ impl Client { .map_err(Error::ReceiveBody) } + pub async fn reset_alert_gauges(&self) -> Result<()> { + let uri = format!( + "{}/hadron-internal/reset_alert_gauges", + self.mgmt_api_endpoint + ); + self.start_request(Method::POST, uri) + .send() + .await + .map_err(Error::SendRequest)? + .error_from_body() + .await? + .json() + .await + .map_err(Error::ReceiveBody) + } + pub async fn wait_lsn( &self, tenant_shard_id: TenantShardId, diff --git a/pageserver/src/auth.rs b/pageserver/src/auth.rs index 4075427ab4..9e97fdaba8 100644 --- a/pageserver/src/auth.rs +++ b/pageserver/src/auth.rs @@ -20,7 +20,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< | Scope::GenerationsApi | Scope::Infra | Scope::Scrubber - | Scope::ControllerPeer, + | Scope::ControllerPeer + | Scope::TenantEndpoint, _, ) => Err(AuthError( format!( diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index 17401dbae8..8b76d980fc 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -817,6 +817,7 @@ impl Timeline { let gc_cutoff_lsn_guard = self.get_applied_gc_cutoff_lsn(); let gc_cutoff_planned = { let gc_info = self.gc_info.read().unwrap(); + info!(cutoffs=?gc_info.cutoffs, applied_cutoff=%*gc_cutoff_lsn_guard, "starting find_lsn_for_timestamp"); gc_info.min_cutoff() }; // Usually the planned cutoff is newer than the cutoff of the last gc run, diff --git a/pgxn/neon/communicator/src/neon_request.rs b/pgxn/neon/communicator/src/neon_request.rs index 70ceaf8744..f777256a5f 100644 --- a/pgxn/neon/communicator/src/neon_request.rs +++ b/pgxn/neon/communicator/src/neon_request.rs @@ -1,4 +1,3 @@ - // Definitions of some core PostgreSQL datatypes. /// XLogRecPtr is defined in "access/xlogdefs.h" as: diff --git a/pgxn/neon/communicator/src/worker_process/logging.rs b/pgxn/neon/communicator/src/worker_process/logging.rs index 3b652e203f..43f51cd332 100644 --- a/pgxn/neon/communicator/src/worker_process/logging.rs +++ b/pgxn/neon/communicator/src/worker_process/logging.rs @@ -1,10 +1,13 @@ //! Glue code to hook up Rust logging with the `tracing` crate to the PostgreSQL log //! //! In the Rust threads, the log messages are written to a mpsc Channel, and the Postgres -//! process latch is raised. That wakes up the loop in the main thread. It reads the -//! message from the channel and ereport()s it. This ensures that only one thread, the main -//! thread, calls the PostgreSQL logging routines at any time. +//! process latch is raised. That wakes up the loop in the main thread, see +//! `communicator_new_bgworker_main()`. It reads the message from the channel and +//! ereport()s it. This ensures that only one thread, the main thread, calls the +//! PostgreSQL logging routines at any time. +use std::ffi::c_char; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::mpsc::sync_channel; use std::sync::mpsc::{Receiver, SyncSender}; use std::sync::mpsc::{TryRecvError, TrySendError}; @@ -12,27 +15,32 @@ use std::sync::mpsc::{TryRecvError, TrySendError}; use tracing::info; use tracing::{Event, Level, Metadata, Subscriber}; use tracing_subscriber::filter::LevelFilter; -use tracing_subscriber::fmt::FmtContext; -use tracing_subscriber::fmt::FormatEvent; -use tracing_subscriber::fmt::FormatFields; -use tracing_subscriber::fmt::FormattedFields; -use tracing_subscriber::fmt::MakeWriter; use tracing_subscriber::fmt::format::Writer; +use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields, FormattedFields, MakeWriter}; use tracing_subscriber::registry::LookupSpan; use crate::worker_process::callbacks::callback_set_my_latch; -pub struct LoggingState { +/// This handle is passed to the C code, and used by [`communicator_worker_poll_logging`] +pub struct LoggingReceiver { receiver: Receiver, } +/// This is passed to `tracing` +struct LoggingSender { + sender: SyncSender, +} + +static DROPPED_EVENT_COUNT: AtomicU64 = AtomicU64::new(0); + /// Called once, at worker process startup. The returned LoggingState is passed back /// in the subsequent calls to `pump_logging`. It is opaque to the C code. #[unsafe(no_mangle)] -pub extern "C" fn configure_logging() -> Box { +pub extern "C" fn communicator_worker_configure_logging() -> Box { let (sender, receiver) = sync_channel(1000); - let maker = Maker { channel: sender }; + let receiver = LoggingReceiver { receiver }; + let sender = LoggingSender { sender }; use tracing_subscriber::prelude::*; let r = tracing_subscriber::registry(); @@ -41,32 +49,45 @@ pub extern "C" fn configure_logging() -> Box { tracing_subscriber::fmt::layer() .with_ansi(false) .event_format(SimpleFormatter::new()) - .with_writer(maker) - // TODO: derive this from log_min_messages? + .with_writer(sender) + // TODO: derive this from log_min_messages? Currently the code in + // communicator_process.c forces log_min_messages='INFO'. .with_filter(LevelFilter::from_level(Level::INFO)), ); r.init(); info!("communicator process logging started"); - let state = LoggingState { receiver }; - - Box::new(state) + Box::new(receiver) } /// Read one message from the logging queue. This is essentially a wrapper to Receiver, /// with a C-friendly signature. /// -/// The message is copied into *errbuf, which is a caller-supplied buffer of size `errbuf_len`. -/// If the message doesn't fit in the buffer, it is truncated. It is always NULL-terminated. +/// The message is copied into *errbuf, which is a caller-supplied buffer of size +/// `errbuf_len`. If the message doesn't fit in the buffer, it is truncated. It is always +/// NULL-terminated. /// -/// The error level is returned *elevel_p. It's one of the PostgreSQL error levels, see elog.h +/// The error level is returned *elevel_p. It's one of the PostgreSQL error levels, see +/// elog.h +/// +/// If there was a message, *dropped_event_count_p is also updated with a counter of how +/// many log messages in total has been dropped. By comparing that with the value from +/// previous call, you can tell how many were dropped since last call. +/// +/// Returns: +/// +/// 0 if there were no messages +/// 1 if there was a message. The message and its level are returned in +/// *errbuf and *elevel_p. *dropped_event_count_p is also updated. +/// -1 on error, i.e the other end of the queue was disconnected #[unsafe(no_mangle)] -pub extern "C" fn pump_logging( - state: &mut LoggingState, - errbuf: *mut u8, +pub extern "C" fn communicator_worker_poll_logging( + state: &mut LoggingReceiver, + errbuf: *mut c_char, errbuf_len: u32, elevel_p: &mut i32, + dropped_event_count_p: &mut u64, ) -> i32 { let msg = match state.receiver.try_recv() { Err(TryRecvError::Empty) => return 0, @@ -75,15 +96,17 @@ pub extern "C" fn pump_logging( }; let src: &[u8] = &msg.message; - let dst = errbuf; + let dst: *mut u8 = errbuf.cast(); let len = std::cmp::min(src.len(), errbuf_len as usize - 1); unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len); - *(errbuf.add(len)) = b'\0'; // NULL terminator + *(dst.add(len)) = b'\0'; // NULL terminator } - // XXX: these levels are copied from PostgreSQL's elog.h. Introduce another enum - // to hide these? + // Map the tracing Level to PostgreSQL elevel. + // + // XXX: These levels are copied from PostgreSQL's elog.h. Introduce another enum to + // hide these? *elevel_p = match msg.level { Level::TRACE => 10, // DEBUG5 Level::DEBUG => 14, // DEBUG1 @@ -92,6 +115,8 @@ pub extern "C" fn pump_logging( Level::ERROR => 21, // ERROR }; + *dropped_event_count_p = DROPPED_EVENT_COUNT.load(Ordering::Relaxed); + 1 } @@ -115,7 +140,7 @@ impl Default for FormattedEventWithMeta { struct EventBuilder<'a> { event: FormattedEventWithMeta, - maker: &'a Maker, + sender: &'a LoggingSender, } impl std::io::Write for EventBuilder<'_> { @@ -123,25 +148,21 @@ impl std::io::Write for EventBuilder<'_> { self.event.message.write(buf) } fn flush(&mut self) -> std::io::Result<()> { - self.maker.send_event(self.event.clone()); + self.sender.send_event(self.event.clone()); Ok(()) } } impl Drop for EventBuilder<'_> { fn drop(&mut self) { - let maker = self.maker; + let sender = self.sender; let event = std::mem::take(&mut self.event); - maker.send_event(event); + sender.send_event(event); } } -struct Maker { - channel: SyncSender, -} - -impl<'a> MakeWriter<'a> for Maker { +impl<'a> MakeWriter<'a> for LoggingSender { type Writer = EventBuilder<'a>; fn make_writer(&'a self) -> Self::Writer { @@ -154,33 +175,38 @@ impl<'a> MakeWriter<'a> for Maker { message: Vec::new(), level: *meta.level(), }, - maker: self, + sender: self, } } } -impl Maker { +impl LoggingSender { fn send_event(&self, e: FormattedEventWithMeta) { - match self.channel.try_send(e) { + match self.sender.try_send(e) { Ok(()) => { // notify the main thread callback_set_my_latch(); } Err(TrySendError::Disconnected(_)) => {} Err(TrySendError::Full(_)) => { - // TODO: record that some messages were lost + // The queue is full, cannot send any more. To avoid blocking the tokio + // thread, simply drop the message. Better to lose some logs than get + // stuck if there's a problem with the logging. + // + // Record the fact that was a message was dropped by incrementing the + // counter. + DROPPED_EVENT_COUNT.fetch_add(1, Ordering::Relaxed); } } } } -/// Simple formatter implementation for tracing_subscriber, which prints the log -/// spans and message part like the default formatter, but no timestamp or error -/// level. The error level is captured separately by `FormattedEventWithMeta', -/// and when the error is printed by the main thread, with PostgreSQL ereport(), -/// it gets a timestamp at that point. (The timestamp printed will therefore lag -/// behind the timestamp on the event here, if the main thread doesn't process -/// the log message promptly) +/// Simple formatter implementation for tracing_subscriber, which prints the log spans and +/// message part like the default formatter, but no timestamp or error level. The error +/// level is captured separately by `FormattedEventWithMeta', and when the error is +/// printed by the main thread, with PostgreSQL ereport(), it gets a timestamp at that +/// point. (The timestamp printed will therefore lag behind the timestamp on the event +/// here, if the main thread doesn't process the log message promptly) struct SimpleFormatter; impl FormatEvent for SimpleFormatter @@ -199,11 +225,10 @@ where for span in scope.from_root() { write!(writer, "{}", span.name())?; - // `FormattedFields` is a formatted representation of the span's - // fields, which is stored in its extensions by the `fmt` layer's - // `new_span` method. The fields will have been formatted - // by the same field formatter that's provided to the event - // formatter in the `FmtContext`. + // `FormattedFields` is a formatted representation of the span's fields, + // which is stored in its extensions by the `fmt` layer's `new_span` + // method. The fields will have been formatted by the same field formatter + // that's provided to the event formatter in the `FmtContext`. let ext = span.extensions(); let fields = &ext .get::>() @@ -220,7 +245,7 @@ where // Write fields on the event ctx.field_format().format_fields(writer.by_ref(), event)?; - writeln!(writer) + Ok(()) } } diff --git a/pgxn/neon/communicator/src/worker_process/worker_interface.rs b/pgxn/neon/communicator/src/worker_process/worker_interface.rs index a7bd79fa83..ff9b1ba699 100644 --- a/pgxn/neon/communicator/src/worker_process/worker_interface.rs +++ b/pgxn/neon/communicator/src/worker_process/worker_interface.rs @@ -30,6 +30,7 @@ pub extern "C" fn communicator_worker_process_launch( file_cache_path: *const c_char, initial_file_cache_size: u64, ) -> &'static CommunicatorWorkerProcessStruct<'static> { + tracing::warn!("starting threads in rust code"); // Convert the arguments into more convenient Rust types let tenant_id = unsafe { CStr::from_ptr(tenant_id) }.to_str().unwrap(); let timeline_id = unsafe { CStr::from_ptr(timeline_id) }.to_str().unwrap(); diff --git a/pgxn/neon/communicator_new.c b/pgxn/neon/communicator_new.c index 87c25af8e5..68501f4ca2 100644 --- a/pgxn/neon/communicator_new.c +++ b/pgxn/neon/communicator_new.c @@ -142,6 +142,7 @@ static bool bounce_needed(void *buffer); static void *bounce_buf(void); static void *bounce_write_if_needed(void *buffer); +static void pump_logging(struct LoggingReceiver *logging); PGDLLEXPORT void communicator_new_bgworker_main(Datum main_arg); static void communicator_new_backend_exit(int code, Datum arg); @@ -184,6 +185,9 @@ pg_init_communicator_new(void) { BackgroundWorker bgw; + if (!neon_use_communicator_worker) + return; + if (pageserver_connstring[0] == '\0' && pageserver_grpc_urls[0] == '\0') { /* running with local storage */ @@ -211,6 +215,9 @@ communicator_new_shmem_size(void) size_t size = 0; int num_request_slots; + if (!neon_use_communicator_worker) + return 0; + size += MAXALIGN( offsetof(CommunicatorShmemData, backends) + MaxProcs * sizeof(CommunicatorShmemPerBackendData) @@ -225,13 +232,16 @@ communicator_new_shmem_size(void) } void -communicator_new_shmem_request(void) +CommunicatorNewShmemRequest(void) { + if (!neon_use_communicator_worker) + return; + RequestAddinShmemSpace(communicator_new_shmem_size()); } void -communicator_new_shmem_startup(void) +CommunicatorNewShmemInit(void) { bool found; int pipefd[2]; @@ -242,6 +252,9 @@ communicator_new_shmem_startup(void) uint64 initial_file_cache_size; uint64 max_file_cache_size; + if (!neon_use_communicator_worker) + return; + rc = pipe(pipefd); if (rc != 0) ereport(ERROR, @@ -294,10 +307,8 @@ communicator_new_bgworker_main(Datum main_arg) { char **connstrings; ShardMap shard_map; - struct LoggingState *logging; - char errbuf[1000]; - int elevel; uint64 file_cache_size; + struct LoggingReceiver *logging; const struct CommunicatorWorkerProcessStruct *proc_handle; /* @@ -334,8 +345,21 @@ communicator_new_bgworker_main(Datum main_arg) for (int i = 0; i < shard_map.num_shards; i++) connstrings[i] = shard_map.connstring[i]; - logging = configure_logging(); + /* + * By default, INFO messages are not printed to the log. We want + * `tracing::info!` messages emitted from the communicator to be printed, + * however, so increase the log level. + * + * XXX: This overrides any user-set value from the config file. That's not + * great, but on the other hand, there should be little reason for user to + * control the verbosity of the communicator. It's not too verbose by + * default. + */ + SetConfigOption("log_min_messages", "INFO", PGC_SUSET, PGC_S_OVERRIDE); + logging = communicator_worker_configure_logging(); + + elog(LOG, "launching worker process threads"); proc_handle = communicator_worker_process_launch( cis, neon_tenant, @@ -348,11 +372,27 @@ communicator_new_bgworker_main(Datum main_arg) file_cache_size); pfree(connstrings); cis = NULL; + if (proc_handle == NULL) + { + /* + * Something went wrong. Before exiting, forward any log messages that + * might've been generated during the failed launch. + */ + pump_logging(logging); + + elog(PANIC, "failure launching threads"); + } elog(LOG, "communicator threads started"); for (;;) { - int32 rc; + ResetLatch(MyLatch); + + /* + * Forward any log messages from the Rust threads into the normal + * Postgres logging facility. + */ + pump_logging(logging); CHECK_FOR_INTERRUPTS(); @@ -383,37 +423,81 @@ communicator_new_bgworker_main(Datum main_arg) pfree(connstrings); } - for (;;) - { - rc = pump_logging(logging, (uint8 *) errbuf, sizeof(errbuf), &elevel); - if (rc == 0) - { - /* nothing to do */ - break; - } - else if (rc == 1) - { - /* Because we don't want to exit on error */ - if (elevel == ERROR) - elevel = LOG; - if (elevel == INFO) - elevel = LOG; - elog(elevel, "[COMMUNICATOR] %s", errbuf); - } - else if (rc == -1) - { - elog(ERROR, "logging channel was closed unexpectedly"); - } - } - (void) WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, 0, PG_WAIT_EXTENSION); - ResetLatch(MyLatch); } } +static void +pump_logging(struct LoggingReceiver *logging) +{ + char errbuf[1000]; + int elevel; + int32 rc; + static uint64_t last_dropped_event_count = 0; + uint64_t dropped_event_count; + uint64_t dropped_now; + + for (;;) + { + rc = communicator_worker_poll_logging(logging, + errbuf, + sizeof(errbuf), + &elevel, + &dropped_event_count); + if (rc == 0) + { + /* nothing to do */ + break; + } + else if (rc == 1) + { + /* Because we don't want to exit on error */ + + if (message_level_is_interesting(elevel)) + { + /* + * Prevent interrupts while cleaning up. + * + * (Not sure if this is required, but all the error handlers + * in Postgres that are installed as sigsetjmp() targets do + * this, so let's follow the example) + */ + HOLD_INTERRUPTS(); + + errstart(elevel, TEXTDOMAIN); + errmsg_internal("[COMMUNICATOR] %s", errbuf); + EmitErrorReport(); + FlushErrorState(); + + /* Now we can allow interrupts again */ + RESUME_INTERRUPTS(); + } + } + else if (rc == -1) + { + elog(ERROR, "logging channel was closed unexpectedly"); + } + } + + /* + * If the queue was full at any time since the last time we reported it, + * report how many messages were lost. We do this outside the loop, so + * that if the logging system is clogged, we don't exacerbate it by + * printing lots of warnings about dropped messages. + */ + dropped_now = dropped_event_count - last_dropped_event_count; + if (dropped_now != 0) + { + elog(WARNING, "%lu communicator log messages were dropped because the log buffer was full", + (unsigned long) dropped_now); + last_dropped_event_count = dropped_event_count; + } +} + + /* * Callbacks from the rust code, in the communicator process. * diff --git a/pgxn/neon/communicator_new.h b/pgxn/neon/communicator_new.h index a19feaaac6..7fbc167f0f 100644 --- a/pgxn/neon/communicator_new.h +++ b/pgxn/neon/communicator_new.h @@ -20,8 +20,8 @@ /* initialization at postmaster startup */ extern void pg_init_communicator_new(void); -extern void communicator_new_shmem_request(void); -extern void communicator_new_shmem_startup(void); +extern void CommunicatorNewShmemRequest(void); +extern void CommunicatorNewShmemInit(void); /* initialization at backend startup */ extern void communicator_new_init(void); diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 754b1ca033..91d8dac274 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -219,10 +219,6 @@ char *lfc_path; static uint64 lfc_generation; static FileCacheControl *lfc_ctl; static bool lfc_do_prewarm; -static shmem_startup_hook_type prev_shmem_startup_hook; -#if PG_VERSION_NUM>=150000 -static shmem_request_hook_type prev_shmem_request_hook; -#endif bool lfc_store_prefetch_result; bool lfc_prewarm_update_ws_estimation; @@ -346,20 +342,17 @@ lfc_ensure_opened(void) return true; } -static void -lfc_shmem_startup(void) +void +LfcShmemInit(void) { bool found; static HASHCTL info; - Assert(!neon_use_communicator_worker); + if (neon_use_communicator_worker) + return; - if (prev_shmem_startup_hook) - { - prev_shmem_startup_hook(); - } - - LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); + if (lfc_max_size <= 0) + return; lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found); if (!found) @@ -404,19 +397,16 @@ lfc_shmem_startup(void) ConditionVariableInit(&lfc_ctl->cv[i]); } - LWLockRelease(AddinShmemInitLock); } -static void -lfc_shmem_request(void) +void +LfcShmemRequest(void) { -#if PG_VERSION_NUM>=150000 - if (prev_shmem_request_hook) - prev_shmem_request_hook(); -#endif - - RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE)); - RequestNamedLWLockTranche("lfc_lock", 1); + if (lfc_max_size > 0) + { + RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE)); + RequestNamedLWLockTranche("lfc_lock", 1); + } } static bool @@ -550,7 +540,6 @@ lfc_init(void) if (!process_shared_preload_libraries_in_progress) neon_log(ERROR, "Neon module should be loaded via shared_preload_libraries"); - DefineCustomBoolVariable("neon.store_prefetch_result_in_lfc", "Immediately store received prefetch result in LFC", NULL, @@ -648,21 +637,6 @@ lfc_init(void) NULL, NULL, NULL); - - if (lfc_max_size == 0) - return; - - if (neon_use_communicator_worker) - return; - - prev_shmem_startup_hook = shmem_startup_hook; - shmem_startup_hook = lfc_shmem_startup; -#if PG_VERSION_NUM>=150000 - prev_shmem_request_hook = shmem_request_hook; - shmem_request_hook = lfc_shmem_request; -#else - lfc_shmem_request(); -#endif } FileCacheState* diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index 1e41527fb5..690dfd8635 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -105,6 +105,11 @@ static int pageserver_response_disconnect_timeout = 150000; * has changed since last access, and to detect and retry copying the value if * the postmaster changes the value concurrently. (Postmaster doesn't have a * PGPROC entry and therefore cannot use LWLocks.) + * + * stripe_size is now also part of ShardMap, although it is defined by separate GUC. + * Postgres doesn't provide any mechanism to enforce dependencies between GUCs, + * that it we we have to rely on order of GUC definition in config file. + * "neon.stripe_size" should be defined prior to "neon.pageserver_connstring" */ typedef struct { @@ -113,10 +118,6 @@ typedef struct ShardMap shard_map; } PagestoreShmemState; -#if PG_VERSION_NUM >= 150000 -static shmem_request_hook_type prev_shmem_request_hook = NULL; -#endif -static shmem_startup_hook_type prev_shmem_startup_hook; static PagestoreShmemState *pagestore_shared; static uint64 pagestore_local_counter = 0; @@ -231,7 +232,10 @@ parse_shard_map(const char *connstr, ShardMap *result) p = sep + 1; } if (result) + { result->num_shards = nshards; + result->stripe_size = neon_stripe_size; + } return true; } @@ -317,12 +321,13 @@ AssignShardMap(const char *newval) * last call, terminates all existing connections to all pageservers. */ static void -load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p) +load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p, size_t* stripe_size_p) { uint64 begin_update_counter; uint64 end_update_counter; ShardMap *shard_map = &pagestore_shared->shard_map; shardno_t num_shards; + size_t stripe_size; /* * Postmaster can update the shared memory values concurrently, in which @@ -337,6 +342,7 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p) end_update_counter = pg_atomic_read_u64(&pagestore_shared->end_update_counter); num_shards = shard_map->num_shards; + stripe_size = shard_map->stripe_size; if (connstr_p && shard_no < MAX_SHARDS) strlcpy(connstr_p, shard_map->connstring[shard_no], MAX_PAGESERVER_CONNSTRING_SIZE); pg_memory_barrier(); @@ -371,6 +377,8 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p) if (num_shards_p) *num_shards_p = num_shards; + if (stripe_size_p) + *stripe_size_p = stripe_size; } #define MB (1024*1024) @@ -379,9 +387,10 @@ shardno_t get_shard_number(BufferTag *tag) { shardno_t n_shards; + size_t stripe_size; uint32 hash; - load_shard_map(0, NULL, &n_shards); + load_shard_map(0, NULL, &n_shards, &stripe_size); #if PG_MAJORVERSION_NUM < 16 hash = murmurhash32(tag->rnode.relNode); @@ -434,7 +443,7 @@ pageserver_connect(shardno_t shard_no, int elevel) * Note that connstr is used both during connection start, and when we * log the successful connection. */ - load_shard_map(shard_no, connstr, NULL); + load_shard_map(shard_no, connstr, NULL, NULL); switch (shard->state) { @@ -1306,18 +1315,12 @@ check_neon_id(char **newval, void **extra, GucSource source) return **newval == '\0' || HexDecodeString(id, *newval, 16); } -static Size -PagestoreShmemSize(void) -{ - return add_size(sizeof(PagestoreShmemState), NeonPerfCountersShmemSize()); -} -static bool +void PagestoreShmemInit(void) { bool found; - LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); pagestore_shared = ShmemInitStruct("libpagestore shared state", sizeof(PagestoreShmemState), &found); @@ -1328,44 +1331,12 @@ PagestoreShmemInit(void) memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap)); AssignPageserverConnstring(pageserver_connstring, NULL); } - - NeonPerfCountersShmemInit(); - - LWLockRelease(AddinShmemInitLock); - return found; } -static void -pagestore_shmem_startup_hook(void) +void +PagestoreShmemRequest(void) { - if (prev_shmem_startup_hook) - prev_shmem_startup_hook(); - - PagestoreShmemInit(); -} - -static void -pagestore_shmem_request(void) -{ -#if PG_VERSION_NUM >= 150000 - if (prev_shmem_request_hook) - prev_shmem_request_hook(); -#endif - - RequestAddinShmemSpace(PagestoreShmemSize()); -} - -static void -pagestore_prepare_shmem(void) -{ -#if PG_VERSION_NUM >= 150000 - prev_shmem_request_hook = shmem_request_hook; - shmem_request_hook = pagestore_shmem_request; -#else - pagestore_shmem_request(); -#endif - prev_shmem_startup_hook = shmem_startup_hook; - shmem_startup_hook = pagestore_shmem_startup_hook; + RequestAddinShmemSpace(sizeof(PagestoreShmemState)); } /* @@ -1374,8 +1345,6 @@ pagestore_prepare_shmem(void) void pg_init_libpagestore(void) { - pagestore_prepare_shmem(); - DefineCustomStringVariable("neon.pageserver_connstring", "connection string to the page server", NULL, @@ -1535,8 +1504,6 @@ pg_init_libpagestore(void) 0, NULL, NULL, NULL); - relsize_hash_init(); - if (page_server != NULL) neon_log(ERROR, "libpagestore already loaded"); diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 10785f748f..8efea63e72 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -23,6 +23,7 @@ #include "replication/walsender.h" #include "storage/ipc.h" #include "storage/proc.h" +#include "storage/ipc.h" #include "funcapi.h" #include "access/htup_details.h" #include "utils/builtins.h" @@ -62,12 +63,13 @@ static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags); static void neon_ExecutorEnd(QueryDesc *queryDesc); static shmem_startup_hook_type prev_shmem_startup_hook; -#if PG_VERSION_NUM>=150000 -static shmem_request_hook_type prev_shmem_request_hook; +static void neon_shmem_startup_hook(void); +static void neon_shmem_request_hook(void); + +#if PG_MAJORVERSION_NUM >= 15 +static shmem_request_hook_type prev_shmem_request_hook = NULL; #endif -static void neon_shmem_request(void); -static void neon_shmem_startup_hook(void); #if PG_MAJORVERSION_NUM >= 17 uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE; @@ -457,15 +459,6 @@ _PG_init(void) load_file("$libdir/neon_rmgr", false); #endif - prev_shmem_startup_hook = shmem_startup_hook; - shmem_startup_hook = neon_shmem_startup_hook; -#if PG_VERSION_NUM>=150000 - prev_shmem_request_hook = shmem_request_hook; - shmem_request_hook = neon_shmem_request; -#else - neon_shmem_request(); -#endif - DefineCustomBoolVariable( "neon.use_communicator_worker", "Uses the communicator worker implementation", @@ -476,14 +469,45 @@ _PG_init(void) 0, NULL, NULL, NULL); + /* + * Initializing a pre-loaded Postgres extension happens in three stages: + * + * 1. _PG_init() is called early at postmaster startup. In this stage, no + * shared memory has been allocated yet. Core Postgres GUCs have been + * initialized from the config files, but notably, MaxBackends has not + * calculated yet. In this stage, we must register any extension GUCs + * and can do other early initialization that doesn't depend on shared + * memory. In this stage we must also register "shmem request" and + * "shmem starutup" hooks, to be called in stages 2 and 3. + * + * 2. After MaxBackends have been calculated, the "shmem request" hooks + * are called. The hooks can reserve shared memory by calling + * RequestAddinShmemSpace and RequestNamedLWLockTranche(). The "shmem + * request hooks" are a new mechanism in Postgres v15. In v14 and + * below, you had to make those Requests in stage 1 already, which + * means they could not depend on MaxBackends. (See hack in + * NeonPerfCountersShmemRequest()) + * + * 3. After some more runtime-computed GUCs that affect the amount of + * shared memory needed have been calculated, the "shmem startup" hooks + * are called. In this stage, we allocate any shared memory, LWLocks + * and other shared resources. + * + * Here, in the 'neon' extension, we register just one shmem request hook + * and one startup hook, which call into functions in all the subsystems + * that are part of the extension. On v14, the ShmemRequest functions are + * called in stage 1, and on v15 onwards they are called in stage 2. + */ + + /* Stage 1: Define GUCs, and other early intialization */ pg_init_libpagestore(); + relsize_hash_init(); lfc_init(); pg_init_walproposer(); init_lwlsncache(); pg_init_communicator(); - if (neon_use_communicator_worker) - pg_init_communicator_new(); + pg_init_communicator_new(); Custom_XLogReaderRoutines = NeonOnDemandXLogReaderRoutines; @@ -582,6 +606,22 @@ _PG_init(void) ReportSearchPath(); + /* + * Register initialization hooks for stage 2. (On v14, there's no "shmem + * request" hooks, so call the ShmemRequest functions immediately.) + */ +#if PG_VERSION_NUM >= 150000 + prev_shmem_request_hook = shmem_request_hook; + shmem_request_hook = neon_shmem_request_hook; +#else + neon_shmem_request_hook(); +#endif + + /* Register hooks for stage 3 */ + prev_shmem_startup_hook = shmem_startup_hook; + shmem_startup_hook = neon_shmem_startup_hook; + + /* Other misc initialization */ prev_ExecutorStart = ExecutorStart_hook; ExecutorStart_hook = neon_ExecutorStart; prev_ExecutorEnd = ExecutorEnd_hook; @@ -673,17 +713,35 @@ approximate_working_set_size(PG_FUNCTION_ARGS) PG_RETURN_INT32(dc); } +/* + * Initialization stage 2: make requests for the amount of shared memory we + * will need. + * + * For a high-level explanation of the initialization process, see _PG_init(). + */ static void -neon_shmem_request(void) +neon_shmem_request_hook(void) { -#if PG_VERSION_NUM>=150000 +#if PG_VERSION_NUM >= 150000 if (prev_shmem_request_hook) prev_shmem_request_hook(); #endif - communicator_new_shmem_request(); + LfcShmemRequest(); + NeonPerfCountersShmemRequest(); + PagestoreShmemRequest(); + RelsizeCacheShmemRequest(); + WalproposerShmemRequest(); + LwLsnCacheShmemRequest(); + CommunicatorNewShmemRequest(); } + +/* + * Initialization stage 3: Initialize shared memory. + * + * For a high-level explanation of the initialization process, see _PG_init(). + */ static void neon_shmem_startup_hook(void) { @@ -691,6 +749,16 @@ neon_shmem_startup_hook(void) if (prev_shmem_startup_hook) prev_shmem_startup_hook(); + LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); + + LfcShmemInit(); + NeonPerfCountersShmemInit(); + PagestoreShmemInit(); + RelsizeCacheShmemInit(); + WalproposerShmemInit(); + LwLsnCacheShmemInit(); + CommunicatorNewShmemInit(); + #if PG_MAJORVERSION_NUM >= 17 WAIT_EVENT_NEON_LFC_MAINTENANCE = WaitEventExtensionNew("Neon/FileCache_Maintenance"); WAIT_EVENT_NEON_LFC_READ = WaitEventExtensionNew("Neon/FileCache_Read"); @@ -704,7 +772,7 @@ neon_shmem_startup_hook(void) WAIT_EVENT_NEON_WAL_DL = WaitEventExtensionNew("Neon/WAL_Download"); #endif - communicator_new_shmem_startup(); + LWLockRelease(AddinShmemInitLock); } /* diff --git a/pgxn/neon/neon.h b/pgxn/neon/neon.h index 215396ef7a..20c850864a 100644 --- a/pgxn/neon/neon.h +++ b/pgxn/neon/neon.h @@ -70,4 +70,19 @@ extern PGDLLEXPORT void WalProposerSync(int argc, char *argv[]); extern PGDLLEXPORT void WalProposerMain(Datum main_arg); extern PGDLLEXPORT void LogicalSlotsMonitorMain(Datum main_arg); +extern void LfcShmemRequest(void); +extern void PagestoreShmemRequest(void); +extern void RelsizeCacheShmemRequest(void); +extern void WalproposerShmemRequest(void); +extern void LwLsnCacheShmemRequest(void); +extern void NeonPerfCountersShmemRequest(void); + +extern void LfcShmemInit(void); +extern void PagestoreShmemInit(void); +extern void RelsizeCacheShmemInit(void); +extern void WalproposerShmemInit(void); +extern void LwLsnCacheShmemInit(void); +extern void NeonPerfCountersShmemInit(void); + + #endif /* NEON_H */ diff --git a/pgxn/neon/neon_lwlsncache.c b/pgxn/neon/neon_lwlsncache.c index a8cfa0f825..5887c02c36 100644 --- a/pgxn/neon/neon_lwlsncache.c +++ b/pgxn/neon/neon_lwlsncache.c @@ -1,5 +1,6 @@ #include "postgres.h" +#include "neon.h" #include "neon_lwlsncache.h" #include "miscadmin.h" @@ -81,14 +82,6 @@ static set_max_lwlsn_hook_type prev_set_max_lwlsn_hook = NULL; static set_lwlsn_relation_hook_type prev_set_lwlsn_relation_hook = NULL; static set_lwlsn_db_hook_type prev_set_lwlsn_db_hook = NULL; -static shmem_startup_hook_type prev_shmem_startup_hook; - -#if PG_VERSION_NUM >= 150000 -static shmem_request_hook_type prev_shmem_request_hook; -#endif - -static void shmemrequest(void); -static void shmeminit(void); static void neon_set_max_lwlsn(XLogRecPtr lsn); void @@ -99,16 +92,6 @@ init_lwlsncache(void) lwlc_register_gucs(); - prev_shmem_startup_hook = shmem_startup_hook; - shmem_startup_hook = shmeminit; - - #if PG_VERSION_NUM >= 150000 - prev_shmem_request_hook = shmem_request_hook; - shmem_request_hook = shmemrequest; - #else - shmemrequest(); - #endif - prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook; set_lwlsn_block_range_hook = neon_set_lwlsn_block_range; prev_set_lwlsn_block_v_hook = set_lwlsn_block_v_hook; @@ -124,20 +107,19 @@ init_lwlsncache(void) } -static void shmemrequest(void) { +void +LwLsnCacheShmemRequest(void) +{ Size requested_size = sizeof(LwLsnCacheCtl); - + requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry)); RequestAddinShmemSpace(requested_size); - - #if PG_VERSION_NUM >= 150000 - if (prev_shmem_request_hook) - prev_shmem_request_hook(); - #endif } -static void shmeminit(void) { +void +LwLsnCacheShmemInit(void) +{ static HASHCTL info; bool found; if (lwlsn_cache_size > 0) @@ -157,9 +139,6 @@ static void shmeminit(void) { } dlist_init(&LwLsnCache->lastWrittenLsnLRU); LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr(); - if (prev_shmem_startup_hook) { - prev_shmem_startup_hook(); - } } /* diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c index d0a3d15108..dd576e4e73 100644 --- a/pgxn/neon/neon_perf_counters.c +++ b/pgxn/neon/neon_perf_counters.c @@ -17,22 +17,32 @@ #include "storage/shmem.h" #include "utils/builtins.h" +#include "neon.h" #include "neon_perf_counters.h" #include "neon_pgversioncompat.h" neon_per_backend_counters *neon_per_backend_counters_shared; -Size -NeonPerfCountersShmemSize(void) +void +NeonPerfCountersShmemRequest(void) { - Size size = 0; - - size = add_size(size, mul_size(NUM_NEON_PERF_COUNTER_SLOTS, - sizeof(neon_per_backend_counters))); - - return size; + Size size; +#if PG_MAJORVERSION_NUM < 15 + /* Hack: in PG14 MaxBackends is not initialized at the time of calling NeonPerfCountersShmemRequest function. + * Do it ourselves and then undo to prevent assertion failure + */ + Assert(MaxBackends == 0); /* not initialized yet */ + InitializeMaxBackends(); + size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters)); + MaxBackends = 0; +#else + size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters)); +#endif + RequestAddinShmemSpace(size); } + + void NeonPerfCountersShmemInit(void) { diff --git a/pgxn/neon/pagestore_client.h b/pgxn/neon/pagestore_client.h index 8ec8ce5408..47417a7bd5 100644 --- a/pgxn/neon/pagestore_client.h +++ b/pgxn/neon/pagestore_client.h @@ -250,6 +250,7 @@ typedef struct { char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE]; size_t num_shards; + size_t stripe_size; } ShardMap; extern bool parse_shard_map(const char *connstr, ShardMap *result); diff --git a/pgxn/neon/relsize_cache.c b/pgxn/neon/relsize_cache.c index f3ceec78cc..613e98f0d4 100644 --- a/pgxn/neon/relsize_cache.c +++ b/pgxn/neon/relsize_cache.c @@ -48,32 +48,23 @@ typedef struct * algorithm */ } RelSizeHashControl; -static HTAB *relsize_hash; -static LWLockId relsize_lock; -static int relsize_hash_size; -static RelSizeHashControl* relsize_ctl; -static shmem_startup_hook_type prev_shmem_startup_hook = NULL; -#if PG_VERSION_NUM >= 150000 -static shmem_request_hook_type prev_shmem_request_hook = NULL; -static void relsize_shmem_request(void); -#endif - /* * Size of a cache entry is 36 bytes. So this default will take about 2.3 MB, * which seems reasonable. */ #define DEFAULT_RELSIZE_HASH_SIZE (64 * 1024) -static void -neon_smgr_shmem_startup(void) +static HTAB *relsize_hash; +static LWLockId relsize_lock; +static int relsize_hash_size = DEFAULT_RELSIZE_HASH_SIZE; +static RelSizeHashControl* relsize_ctl; + +void +RelsizeCacheShmemInit(void) { static HASHCTL info; bool found; - if (prev_shmem_startup_hook) - prev_shmem_startup_hook(); - - LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); relsize_ctl = (RelSizeHashControl *) ShmemInitStruct("relsize_hash", sizeof(RelSizeHashControl), &found); if (!found) { @@ -84,7 +75,6 @@ neon_smgr_shmem_startup(void) relsize_hash_size, relsize_hash_size, &info, HASH_ELEM | HASH_BLOBS); - LWLockRelease(AddinShmemInitLock); relsize_ctl->size = 0; relsize_ctl->hits = 0; relsize_ctl->misses = 0; @@ -249,34 +239,15 @@ relsize_hash_init(void) PGC_POSTMASTER, 0, NULL, NULL, NULL); - - if (relsize_hash_size > 0) - { -#if PG_VERSION_NUM >= 150000 - prev_shmem_request_hook = shmem_request_hook; - shmem_request_hook = relsize_shmem_request; -#else - RequestAddinShmemSpace(hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry))); - RequestNamedLWLockTranche("neon_relsize", 1); -#endif - - prev_shmem_startup_hook = shmem_startup_hook; - shmem_startup_hook = neon_smgr_shmem_startup; - } } -#if PG_VERSION_NUM >= 150000 /* * shmem_request hook: request additional shared resources. We'll allocate or * attach to the shared resources in neon_smgr_shmem_startup(). */ -static void -relsize_shmem_request(void) +void +RelsizeCacheShmemRequest(void) { - if (prev_shmem_request_hook) - prev_shmem_request_hook(); - RequestAddinShmemSpace(sizeof(RelSizeHashControl) + hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry))); RequestNamedLWLockTranche("neon_relsize", 1); } -#endif diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 18655d4c6c..9ed8d0d2d2 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -83,10 +83,8 @@ static XLogRecPtr standby_flush_lsn = InvalidXLogRecPtr; static XLogRecPtr standby_apply_lsn = InvalidXLogRecPtr; static HotStandbyFeedback agg_hs_feedback; -static void nwp_shmem_startup_hook(void); static void nwp_register_gucs(void); static void assign_neon_safekeepers(const char *newval, void *extra); -static void nwp_prepare_shmem(void); static uint64 backpressure_lag_impl(void); static uint64 startup_backpressure_wrap(void); static bool backpressure_throttling_impl(void); @@ -99,11 +97,6 @@ static TimestampTz walprop_pg_get_current_timestamp(WalProposer *wp); static void walprop_pg_load_libpqwalreceiver(void); static process_interrupts_callback_t PrevProcessInterruptsCallback = NULL; -static shmem_startup_hook_type prev_shmem_startup_hook_type; -#if PG_VERSION_NUM >= 150000 -static shmem_request_hook_type prev_shmem_request_hook = NULL; -static void walproposer_shmem_request(void); -#endif static void WalproposerShmemInit_SyncSafekeeper(void); @@ -193,8 +186,6 @@ pg_init_walproposer(void) nwp_register_gucs(); - nwp_prepare_shmem(); - delay_backend_us = &startup_backpressure_wrap; PrevProcessInterruptsCallback = ProcessInterruptsCallback; ProcessInterruptsCallback = backpressure_throttling_impl; @@ -494,12 +485,11 @@ WalproposerShmemSize(void) return sizeof(WalproposerShmemState); } -static bool +void WalproposerShmemInit(void) { bool found; - LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); walprop_shared = ShmemInitStruct("Walproposer shared state", sizeof(WalproposerShmemState), &found); @@ -517,9 +507,6 @@ WalproposerShmemInit(void) pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0); /* END_HADRON */ } - LWLockRelease(AddinShmemInitLock); - - return found; } static void @@ -623,42 +610,15 @@ walprop_register_bgworker(void) /* shmem handling */ -static void -nwp_prepare_shmem(void) -{ -#if PG_VERSION_NUM >= 150000 - prev_shmem_request_hook = shmem_request_hook; - shmem_request_hook = walproposer_shmem_request; -#else - RequestAddinShmemSpace(WalproposerShmemSize()); -#endif - prev_shmem_startup_hook_type = shmem_startup_hook; - shmem_startup_hook = nwp_shmem_startup_hook; -} - -#if PG_VERSION_NUM >= 150000 /* * shmem_request hook: request additional shared resources. We'll allocate or - * attach to the shared resources in nwp_shmem_startup_hook(). + * attach to the shared resources in WalproposerShmemInit(). */ -static void -walproposer_shmem_request(void) +void +WalproposerShmemRequest(void) { - if (prev_shmem_request_hook) - prev_shmem_request_hook(); - RequestAddinShmemSpace(WalproposerShmemSize()); } -#endif - -static void -nwp_shmem_startup_hook(void) -{ - if (prev_shmem_startup_hook_type) - prev_shmem_startup_hook_type(); - - WalproposerShmemInit(); -} WalproposerShmemState * GetWalpropShmemState(void) diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index c812779e30..0ef09a8a9a 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -10,6 +10,7 @@ use tokio::time::Instant; use tracing::{debug, info}; use crate::config::ProjectInfoCacheOptions; +use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason}; use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; @@ -36,22 +37,37 @@ impl Entry { } pub(crate) fn get(&self) -> Option<&T> { - (self.expires_at > Instant::now()).then_some(&self.value) + (!self.is_expired()).then_some(&self.value) + } + + fn is_expired(&self) -> bool { + self.expires_at <= Instant::now() } } struct EndpointInfo { - role_controls: HashMap>, - controls: Option>, + role_controls: HashMap>>, + controls: Option>>, } +type ControlPlaneResult = Result>; + impl EndpointInfo { - pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option { - self.role_controls.get(&role_name)?.get().cloned() + pub(crate) fn get_role_secret_with_ttl( + &self, + role_name: RoleNameInt, + ) -> Option<(ControlPlaneResult, Duration)> { + let entry = self.role_controls.get(&role_name)?; + let ttl = entry.expires_at - Instant::now(); + Some((entry.get()?.clone(), ttl)) } - pub(crate) fn get_controls(&self) -> Option { - self.controls.as_ref()?.get().cloned() + pub(crate) fn get_controls_with_ttl( + &self, + ) -> Option<(ControlPlaneResult, Duration)> { + let entry = self.controls.as_ref()?; + let ttl = entry.expires_at - Instant::now(); + Some((entry.get()?.clone(), ttl)) } pub(crate) fn invalidate_endpoint(&mut self) { @@ -153,28 +169,28 @@ impl ProjectInfoCacheImpl { self.cache.get(&endpoint_id) } - pub(crate) fn get_role_secret( + pub(crate) fn get_role_secret_with_ttl( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option { + ) -> Option<(ControlPlaneResult, Duration)> { let role_name = RoleNameInt::get(role_name)?; let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_role_secret(role_name) + endpoint_info.get_role_secret_with_ttl(role_name) } - pub(crate) fn get_endpoint_access( + pub(crate) fn get_endpoint_access_with_ttl( &self, endpoint_id: &EndpointId, - ) -> Option { + ) -> Option<(ControlPlaneResult, Duration)> { let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_controls() + endpoint_info.get_controls_with_ttl() } pub(crate) fn insert_endpoint_access( &self, account_id: Option, - project_id: ProjectIdInt, + project_id: Option, endpoint_id: EndpointIdInt, role_name: RoleNameInt, controls: EndpointAccessControl, @@ -183,26 +199,89 @@ impl ProjectInfoCacheImpl { if let Some(account_id) = account_id { self.insert_account2endpoint(account_id, endpoint_id); } - self.insert_project2endpoint(project_id, endpoint_id); + if let Some(project_id) = project_id { + self.insert_project2endpoint(project_id, endpoint_id); + } if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. return; } - let controls = Entry::new(controls, self.config.ttl); - let role_controls = Entry::new(role_controls, self.config.ttl); + debug!( + key = &*endpoint_id, + "created a cache entry for endpoint access" + ); + + let controls = Some(Entry::new(Ok(controls), self.config.ttl)); + let role_controls = Entry::new(Ok(role_controls), self.config.ttl); match self.cache.entry(endpoint_id) { clashmap::Entry::Vacant(e) => { e.insert(EndpointInfo { role_controls: HashMap::from_iter([(role_name, role_controls)]), - controls: Some(controls), + controls, }); } clashmap::Entry::Occupied(mut e) => { let ep = e.get_mut(); - ep.controls = Some(controls); + ep.controls = controls; + if ep.role_controls.len() < self.config.max_roles { + ep.role_controls.insert(role_name, role_controls); + } + } + } + } + + pub(crate) fn insert_endpoint_access_err( + &self, + endpoint_id: EndpointIdInt, + role_name: RoleNameInt, + msg: Box, + ttl: Option, + ) { + if self.cache.len() >= self.config.size { + // If there are too many entries, wait until the next gc cycle. + return; + } + + debug!( + key = &*endpoint_id, + "created a cache entry for an endpoint access error" + ); + + let ttl = ttl.unwrap_or(self.config.ttl); + + let controls = if msg.get_reason() == Reason::RoleProtected { + // RoleProtected is the only role-specific error that control plane can give us. + // If a given role name does not exist, it still returns a successful response, + // just with an empty secret. + None + } else { + // We can cache all the other errors in EndpointInfo.controls, + // because they don't depend on what role name we pass to control plane. + Some(Entry::new(Err(msg.clone()), ttl)) + }; + + let role_controls = Entry::new(Err(msg), ttl); + + match self.cache.entry(endpoint_id) { + clashmap::Entry::Vacant(e) => { + e.insert(EndpointInfo { + role_controls: HashMap::from_iter([(role_name, role_controls)]), + controls, + }); + } + clashmap::Entry::Occupied(mut e) => { + let ep = e.get_mut(); + if let Some(entry) = &ep.controls + && !entry.is_expired() + && entry.value.is_ok() + { + // If we have cached non-expired, non-error controls, keep them. + } else { + ep.controls = controls; + } if ep.role_controls.len() < self.config.max_roles { ep.role_controls.insert(role_name, role_controls); } @@ -245,7 +324,7 @@ impl ProjectInfoCacheImpl { return; }; - if role_controls.get().expires_at <= Instant::now() { + if role_controls.get().is_expired() { role_controls.remove(); } } @@ -284,13 +363,11 @@ impl ProjectInfoCacheImpl { #[cfg(test)] mod tests { - use std::sync::Arc; - use super::*; - use crate::control_plane::messages::EndpointRateLimitConfig; + use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status}; use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; - use crate::types::ProjectId; + use std::sync::Arc; #[tokio::test] async fn test_project_info_cache_settings() { @@ -301,9 +378,9 @@ mod tests { ttl: Duration::from_secs(1), gc_interval: Duration::from_secs(600), }); - let project_id: ProjectId = "project".into(); + let project_id: Option = Some(ProjectIdInt::from(&"project".into())); let endpoint_id: EndpointId = "endpoint".into(); - let account_id: Option = None; + let account_id = None; let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); @@ -316,7 +393,7 @@ mod tests { cache.insert_endpoint_access( account_id, - (&project_id).into(), + project_id, (&endpoint_id).into(), (&user1).into(), EndpointAccessControl { @@ -332,7 +409,7 @@ mod tests { cache.insert_endpoint_access( account_id, - (&project_id).into(), + project_id, (&endpoint_id).into(), (&user2).into(), EndpointAccessControl { @@ -346,11 +423,17 @@ mod tests { }, ); - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert_eq!(cached.secret, secret1); + let (cached, ttl) = cache + .get_role_secret_with_ttl(&endpoint_id, &user1) + .unwrap(); + assert_eq!(cached.unwrap().secret, secret1); + assert_eq!(ttl, cache.config.ttl); - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert_eq!(cached.secret, secret2); + let (cached, ttl) = cache + .get_role_secret_with_ttl(&endpoint_id, &user2) + .unwrap(); + assert_eq!(cached.unwrap().secret, secret2); + assert_eq!(ttl, cache.config.ttl); // Shouldn't add more than 2 roles. let user3: RoleName = "user3".into(); @@ -358,7 +441,7 @@ mod tests { cache.insert_endpoint_access( account_id, - (&project_id).into(), + project_id, (&endpoint_id).into(), (&user3).into(), EndpointAccessControl { @@ -372,17 +455,144 @@ mod tests { }, ); - assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); + assert!( + cache + .get_role_secret_with_ttl(&endpoint_id, &user3) + .is_none() + ); - let cached = cache.get_endpoint_access(&endpoint_id).unwrap(); + let cached = cache + .get_endpoint_access_with_ttl(&endpoint_id) + .unwrap() + .0 + .unwrap(); assert_eq!(cached.allowed_ips, allowed_ips); tokio::time::advance(Duration::from_secs(2)).await; - let cached = cache.get_role_secret(&endpoint_id, &user1); + let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1); assert!(cached.is_none()); - let cached = cache.get_role_secret(&endpoint_id, &user2); + let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2); assert!(cached.is_none()); - let cached = cache.get_endpoint_access(&endpoint_id); + let cached = cache.get_endpoint_access_with_ttl(&endpoint_id); assert!(cached.is_none()); } + + #[tokio::test] + async fn test_caching_project_info_errors() { + let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { + size: 10, + max_roles: 10, + ttl: Duration::from_secs(1), + gc_interval: Duration::from_secs(600), + }); + let project_id = Some(ProjectIdInt::from(&"project".into())); + let endpoint_id: EndpointId = "endpoint".into(); + let account_id = None; + + let user1: RoleName = "user1".into(); + let user2: RoleName = "user2".into(); + let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); + + let role_msg = Box::new(ControlPlaneErrorMessage { + error: "role is protected and cannot be used for password-based authentication" + .to_owned() + .into_boxed_str(), + http_status_code: http::StatusCode::NOT_FOUND, + status: Some(Status { + code: "PERMISSION_DENIED".to_owned().into_boxed_str(), + message: "role is protected and cannot be used for password-based authentication" + .to_owned() + .into_boxed_str(), + details: Details { + error_info: Some(ErrorInfo { + reason: Reason::RoleProtected, + }), + retry_info: None, + user_facing_message: None, + }, + }), + }); + + let generic_msg = Box::new(ControlPlaneErrorMessage { + error: "oh noes".to_owned().into_boxed_str(), + http_status_code: http::StatusCode::NOT_FOUND, + status: None, + }); + + let get_role_secret = |endpoint_id, role_name| { + cache + .get_role_secret_with_ttl(endpoint_id, role_name) + .unwrap() + .0 + }; + let get_endpoint_access = + |endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0; + + // stores role-specific errors only for get_role_secret + cache.insert_endpoint_access_err( + (&endpoint_id).into(), + (&user1).into(), + role_msg.clone(), + None, + ); + assert_eq!( + get_role_secret(&endpoint_id, &user1).unwrap_err().error, + role_msg.error + ); + assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none()); + + // stores non-role specific errors for both get_role_secret and get_endpoint_access + cache.insert_endpoint_access_err( + (&endpoint_id).into(), + (&user1).into(), + generic_msg.clone(), + None, + ); + assert_eq!( + get_role_secret(&endpoint_id, &user1).unwrap_err().error, + generic_msg.error + ); + assert_eq!( + get_endpoint_access(&endpoint_id).unwrap_err().error, + generic_msg.error + ); + + // error isn't returned for other roles in the same endpoint + assert!( + cache + .get_role_secret_with_ttl(&endpoint_id, &user2) + .is_none() + ); + + // success for a role does not overwrite errors for other roles + cache.insert_endpoint_access( + account_id, + project_id, + (&endpoint_id).into(), + (&user2).into(), + EndpointAccessControl { + allowed_ips: Arc::new(vec![]), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), + }, + RoleAccessControl { + secret: secret.clone(), + }, + ); + assert!(get_role_secret(&endpoint_id, &user1).is_err()); + assert!(get_role_secret(&endpoint_id, &user2).is_ok()); + // ...but does clear the access control error + assert!(get_endpoint_access(&endpoint_id).is_ok()); + + // storing an error does not overwrite successful access control response + cache.insert_endpoint_access_err( + (&endpoint_id).into(), + (&user2).into(), + generic_msg.clone(), + None, + ); + assert!(get_role_secret(&endpoint_id, &user2).is_err()); + assert!(get_endpoint_access(&endpoint_id).is_ok()); + } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 77062d3bb4..f25121331f 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -32,8 +32,11 @@ use crate::util::run_until; type IpSubnetKey = IpNet; -const CANCEL_KEY_TTL: Duration = Duration::from_secs(600); -const CANCEL_KEY_REFRESH: Duration = Duration::from_secs(570); +/// Initial period and TTL is shorter to clear keys of short-lived connections faster. +const CANCEL_KEY_INITIAL_PERIOD: Duration = Duration::from_secs(60); +const CANCEL_KEY_REFRESH_PERIOD: Duration = Duration::from_secs(10 * 60); +/// `CANCEL_KEY_TTL_SLACK` is added to the periods to determine the actual TTL. +const CANCEL_KEY_TTL_SLACK: Duration = Duration::from_secs(30); // Message types for sending through mpsc channel pub enum CancelKeyOp { @@ -54,6 +57,24 @@ pub enum CancelKeyOp { }, } +impl CancelKeyOp { + const fn redis_msg_kind(&self) -> RedisMsgKind { + match self { + CancelKeyOp::Store { .. } => RedisMsgKind::Set, + CancelKeyOp::Refresh { .. } => RedisMsgKind::Expire, + CancelKeyOp::Get { .. } => RedisMsgKind::Get, + CancelKeyOp::GetOld { .. } => RedisMsgKind::HGet, + } + } + + fn cancel_channel_metric_guard(&self) -> CancelChannelSizeGuard<'static> { + Metrics::get() + .proxy + .cancel_channel_size + .guard(self.redis_msg_kind()) + } +} + #[derive(thiserror::Error, Debug, Clone)] pub enum PipelineError { #[error("could not send cmd to redis: {0}")] @@ -483,50 +504,49 @@ impl Session { let mut cancel = pin!(cancel); enum State { - Set, + Init, Refresh, } - let mut state = State::Set; + let mut state = State::Init; loop { - let guard_op = match state { - State::Set => { - let guard = Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::Set); - let op = CancelKeyOp::Store { - key: self.key, - value: closure_json.clone(), - expire: CANCEL_KEY_TTL, - }; + let (op, mut wait_interval) = match state { + State::Init => { tracing::debug!( src=%self.key, dest=?cancel_closure.cancel_token, "registering cancellation key" ); - (guard, op) + ( + CancelKeyOp::Store { + key: self.key, + value: closure_json.clone(), + expire: CANCEL_KEY_INITIAL_PERIOD + CANCEL_KEY_TTL_SLACK, + }, + CANCEL_KEY_INITIAL_PERIOD, + ) } State::Refresh => { - let guard = Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::Expire); - let op = CancelKeyOp::Refresh { - key: self.key, - expire: CANCEL_KEY_TTL, - }; tracing::debug!( src=%self.key, dest=?cancel_closure.cancel_token, "refreshing cancellation key" ); - (guard, op) + ( + CancelKeyOp::Refresh { + key: self.key, + expire: CANCEL_KEY_REFRESH_PERIOD + CANCEL_KEY_TTL_SLACK, + }, + CANCEL_KEY_REFRESH_PERIOD, + ) } }; - match tx.call(guard_op, cancel.as_mut()).await { + match tx + .call((op.cancel_channel_metric_guard(), op), cancel.as_mut()) + .await + { // SET returns OK Ok(Value::Okay) => { tracing::debug!( @@ -549,23 +569,23 @@ impl Session { Ok(_) => { // Any other response likely means the key expired. tracing::warn!(src=%self.key, "refreshing cancellation key failed"); - // Re-enter the SET loop to repush full data. - state = State::Set; + // Re-enter the SET loop quickly to repush full data. + state = State::Init; + wait_interval = Duration::ZERO; } // retry immediately. Err(BatchQueueError::Result(error)) => { tracing::warn!(?error, "error refreshing cancellation key"); // Small delay to prevent busy loop with high cpu and logging. - tokio::time::sleep(Duration::from_millis(10)).await; - continue; + wait_interval = Duration::from_millis(10); } Err(BatchQueueError::Cancelled(Err(_cancelled))) => break, } // wait before continuing. break immediately if cancelled. - if run_until(tokio::time::sleep(CANCEL_KEY_REFRESH), cancel.as_mut()) + if run_until(tokio::time::sleep(wait_interval), cancel.as_mut()) .await .is_err() { diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index bb785b8b0c..8a0403c0b0 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -68,6 +68,66 @@ impl NeonControlPlaneClient { self.endpoint.url().as_str() } + async fn get_and_cache_auth_info( + &self, + ctx: &RequestContext, + endpoint: &EndpointId, + role: &RoleName, + cache_key: &EndpointId, + extract: impl FnOnce(&EndpointAccessControl, &RoleAccessControl) -> T, + ) -> Result { + match self.do_get_auth_req(ctx, endpoint, role).await { + Ok(auth_info) => { + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + rate_limits: auth_info.rate_limits, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + let res = extract(&control, &role_control); + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, + auth_info.project_id, + cache_key.into(), + role.into(), + control, + role_control, + ); + + if let Some(project_id) = auth_info.project_id { + ctx.set_project_id(project_id); + } + + Ok(res) + } + Err(err) => match err { + GetAuthInfoError::ApiError(ControlPlaneError::Message(ref msg)) => { + let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info); + + // If we can retry this error, do not cache it, + // unless we were given a retry delay. + if msg.could_retry() && retry_info.is_none() { + return Err(err); + } + + self.caches.project_info.insert_endpoint_access_err( + cache_key.into(), + role.into(), + msg.clone(), + retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)), + ); + + Err(err) + } + err => Err(err), + }, + } + } + async fn do_get_auth_req( &self, ctx: &RequestContext, @@ -284,43 +344,34 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ctx: &RequestContext, endpoint: &EndpointId, role: &RoleName, - ) -> Result { - let normalized_ep = &endpoint.normalize(); - if let Some(secret) = self + ) -> Result { + let key = endpoint.normalize(); + + if let Some((role_control, ttl)) = self .caches .project_info - .get_role_secret(normalized_ep, role) + .get_role_secret_with_ttl(&key, role) { - return Ok(secret); + return match role_control { + Err(mut msg) => { + info!(key = &*key, "found cached get_role_access_control error"); + + // if retry_delay_ms is set change it to the remaining TTL + replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); + + Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) + } + Ok(role_control) => { + debug!(key = &*key, "found cached role access control"); + Ok(role_control) + } + }; } - let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; - - let control = EndpointAccessControl { - allowed_ips: Arc::new(auth_info.allowed_ips), - allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), - flags: auth_info.access_blocker_flags, - rate_limits: auth_info.rate_limits, - }; - let role_control = RoleAccessControl { - secret: auth_info.secret, - }; - - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - - self.caches.project_info.insert_endpoint_access( - auth_info.account_id, - project_id, - normalized_ep_int, - role.into(), - control, - role_control.clone(), - ); - ctx.set_project_id(project_id); - } - - Ok(role_control) + self.get_and_cache_auth_info(ctx, endpoint, role, &key, |_, role_control| { + role_control.clone() + }) + .await } #[tracing::instrument(skip_all)] @@ -330,38 +381,30 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { endpoint: &EndpointId, role: &RoleName, ) -> Result { - let normalized_ep = &endpoint.normalize(); - if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) { - return Ok(control); + let key = endpoint.normalize(); + + if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) { + return match control { + Err(mut msg) => { + info!( + key = &*key, + "found cached get_endpoint_access_control error" + ); + + // if retry_delay_ms is set change it to the remaining TTL + replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); + + Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) + } + Ok(control) => { + debug!(key = &*key, "found cached endpoint access control"); + Ok(control) + } + }; } - let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; - - let control = EndpointAccessControl { - allowed_ips: Arc::new(auth_info.allowed_ips), - allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), - flags: auth_info.access_blocker_flags, - rate_limits: auth_info.rate_limits, - }; - let role_control = RoleAccessControl { - secret: auth_info.secret, - }; - - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - - self.caches.project_info.insert_endpoint_access( - auth_info.account_id, - project_id, - normalized_ep_int, - role.into(), - control.clone(), - role_control, - ); - ctx.set_project_id(project_id); - } - - Ok(control) + self.get_and_cache_auth_info(ctx, endpoint, role, &key, |control, _| control.clone()) + .await } #[tracing::instrument(skip_all)] @@ -390,13 +433,9 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { info!(key = &*key, "found cached wake_compute error"); // if retry_delay_ms is set, reduce it by the amount of time it spent in cache - if let Some(status) = &mut msg.status { - if let Some(retry_info) = &mut status.details.retry_info { - retry_info.retry_delay_ms = retry_info - .retry_delay_ms - .saturating_sub(created_at.elapsed().as_millis() as u64) - } - } + replace_retry_delay_ms(&mut msg, |delay| { + delay.saturating_sub(created_at.elapsed().as_millis() as u64) + }); Err(WakeComputeError::ControlPlane(ControlPlaneError::Message( msg, @@ -478,6 +517,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } } +fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) { + if let Some(status) = &mut msg.status + && let Some(retry_info) = &mut status.details.retry_info + { + retry_info.retry_delay_ms = f(retry_info.retry_delay_ms); + } +} + /// Parse http response body, taking status code into account. fn parse_body serde::Deserialize<'a>>( status: StatusCode, diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs index 12843e48c7..1e43010957 100644 --- a/proxy/src/control_plane/errors.rs +++ b/proxy/src/control_plane/errors.rs @@ -52,7 +52,7 @@ impl ReportableError for ControlPlaneError { | Reason::EndpointNotFound | Reason::EndpointDisabled | Reason::BranchNotFound - | Reason::InvalidEphemeralEndpointOptions => ErrorKind::User, + | Reason::WrongLsnOrTimestamp => ErrorKind::User, Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit, diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index cf193ed268..d44d7efcc3 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -107,7 +107,7 @@ pub(crate) struct ErrorInfo { // Schema could also have `metadata` field, but it's not structured. Skip it for now. } -#[derive(Clone, Copy, Debug, Deserialize, Default)] +#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)] pub(crate) enum Reason { /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles. #[serde(rename = "ROLE_PROTECTED")] @@ -133,9 +133,9 @@ pub(crate) enum Reason { /// or that the subject doesn't have enough permissions to access the requested branch. #[serde(rename = "BRANCH_NOT_FOUND")] BranchNotFound, - /// InvalidEphemeralEndpointOptions indicates that the specified LSN or timestamp are wrong. - #[serde(rename = "INVALID_EPHEMERAL_OPTIONS")] - InvalidEphemeralEndpointOptions, + /// WrongLsnOrTimestamp indicates that the specified LSN or timestamp are wrong. + #[serde(rename = "WRONG_LSN_OR_TIMESTAMP")] + WrongLsnOrTimestamp, /// RateLimitExceeded indicates that the rate limit for the operation has been exceeded. #[serde(rename = "RATE_LIMIT_EXCEEDED")] RateLimitExceeded, @@ -205,7 +205,7 @@ impl Reason { | Reason::EndpointNotFound | Reason::EndpointDisabled | Reason::BranchNotFound - | Reason::InvalidEphemeralEndpointOptions => false, + | Reason::WrongLsnOrTimestamp => false, // we were asked to go away Reason::RateLimitExceeded | Reason::NonDefaultBranchComputeTimeExceeded @@ -257,19 +257,19 @@ pub(crate) struct GetEndpointAccessControl { pub(crate) rate_limits: EndpointRateLimitConfig, } -#[derive(Copy, Clone, Deserialize, Default)] +#[derive(Copy, Clone, Deserialize, Default, Debug)] pub struct EndpointRateLimitConfig { pub connection_attempts: ConnectionAttemptsLimit, } -#[derive(Copy, Clone, Deserialize, Default)] +#[derive(Copy, Clone, Deserialize, Default, Debug)] pub struct ConnectionAttemptsLimit { pub tcp: Option, pub ws: Option, pub http: Option, } -#[derive(Copy, Clone, Deserialize)] +#[derive(Copy, Clone, Deserialize, Debug)] pub struct LeakyBucketSetting { pub rps: f64, pub burst: f64, diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index a8c59dad0c..9bbd3f4fb7 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -82,7 +82,7 @@ impl NodeInfo { } } -#[derive(Copy, Clone, Default)] +#[derive(Copy, Clone, Default, Debug)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, @@ -92,12 +92,12 @@ pub(crate) type NodeInfoCache = TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct RoleAccessControl { pub secret: Option, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct EndpointAccessControl { pub allowed_ips: Arc>, pub allowed_vpce: Arc>, diff --git a/safekeeper/src/auth.rs b/safekeeper/src/auth.rs index 81c79fae30..008f903a89 100644 --- a/safekeeper/src/auth.rs +++ b/safekeeper/src/auth.rs @@ -21,7 +21,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< | Scope::GenerationsApi | Scope::Infra | Scope::Scrubber - | Scope::ControllerPeer, + | Scope::ControllerPeer + | Scope::TenantEndpoint, _, ) => Err(AuthError( format!( diff --git a/storage_controller/Cargo.toml b/storage_controller/Cargo.toml index 143f4241f4..d67be6d469 100644 --- a/storage_controller/Cargo.toml +++ b/storage_controller/Cargo.toml @@ -52,6 +52,7 @@ tokio-rustls.workspace = true tokio-util.workspace = true tokio.workspace = true tracing.workspace = true +uuid.workspace = true measured.workspace = true rustls.workspace = true scopeguard.workspace = true @@ -63,6 +64,7 @@ tokio-postgres-rustls.workspace = true diesel = { version = "2.2.6", features = [ "serde_json", "chrono", + "uuid", ] } diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connection-wrapper"] } diesel_migrations = { version = "2.2.0" } diff --git a/storage_controller/migrations/2025-07-17-000001_hadron_safekeepers/down.sql b/storage_controller/migrations/2025-07-17-000001_hadron_safekeepers/down.sql new file mode 100644 index 0000000000..b45b45e438 --- /dev/null +++ b/storage_controller/migrations/2025-07-17-000001_hadron_safekeepers/down.sql @@ -0,0 +1,2 @@ +DROP TABLE hadron_safekeepers; +DROP TABLE hadron_timeline_safekeepers; diff --git a/storage_controller/migrations/2025-07-17-000001_hadron_safekeepers/up.sql b/storage_controller/migrations/2025-07-17-000001_hadron_safekeepers/up.sql new file mode 100644 index 0000000000..6cee981efc --- /dev/null +++ b/storage_controller/migrations/2025-07-17-000001_hadron_safekeepers/up.sql @@ -0,0 +1,17 @@ +-- hadron_safekeepers keep track of all Safe Keeper nodes that exist in the system. +-- Upon startup, each Safe Keeper reaches out to the hadron cluster coordinator to register its node ID and listen addresses. + +CREATE TABLE hadron_safekeepers ( + sk_node_id BIGINT PRIMARY KEY NOT NULL, + listen_http_addr VARCHAR NOT NULL, + listen_http_port INTEGER NOT NULL, + listen_pg_addr VARCHAR NOT NULL, + listen_pg_port INTEGER NOT NULL +); + +CREATE TABLE hadron_timeline_safekeepers ( + timeline_id VARCHAR NOT NULL, + sk_node_id BIGINT NOT NULL, + legacy_endpoint_id UUID DEFAULT NULL, + PRIMARY KEY(timeline_id, sk_node_id) +); diff --git a/storage_controller/src/auth.rs b/storage_controller/src/auth.rs index ef47abf8c7..8f15f0f072 100644 --- a/storage_controller/src/auth.rs +++ b/storage_controller/src/auth.rs @@ -1,4 +1,5 @@ use utils::auth::{AuthError, Claims, Scope}; +use uuid::Uuid; pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), AuthError> { if claims.scope != required_scope { @@ -7,3 +8,14 @@ pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), Au Ok(()) } + +#[allow(dead_code)] +pub fn check_endpoint_permission(claims: &Claims, endpoint_id: Uuid) -> Result<(), AuthError> { + if claims.scope != Scope::TenantEndpoint { + return Err(AuthError("Scope mismatch. Permission denied".into())); + } + if claims.endpoint_id != Some(endpoint_id) { + return Err(AuthError("Endpoint id mismatch. Permission denied".into())); + } + Ok(()) +} diff --git a/storage_controller/src/compute_hook.rs b/storage_controller/src/compute_hook.rs index 81f7fedac8..9e4e528b16 100644 --- a/storage_controller/src/compute_hook.rs +++ b/storage_controller/src/compute_hook.rs @@ -5,6 +5,8 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Context; +use compute_api::spec::PageserverProtocol; +use compute_api::spec::PageserverShardInfo; use control_plane::endpoint::{ ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo, }; @@ -13,7 +15,7 @@ use futures::StreamExt; use hyper::StatusCode; use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT; use pageserver_api::controller_api::AvailabilityZone; -use pageserver_api::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId}; +use pageserver_api::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize, TenantShardId}; use postgres_connection::parse_host_port; use safekeeper_api::membership::SafekeeperGeneration; use serde::{Deserialize, Serialize}; @@ -507,7 +509,16 @@ impl ApiMethod for ComputeHookTenant { if endpoint.tenant_id == *tenant_id && endpoint.status() == EndpointStatus::Running { tracing::info!("Reconfiguring pageservers for endpoint {endpoint_name}"); - let mut shard_conninfos = HashMap::new(); + let shard_count = ShardCount(shards.len().try_into().expect("too many shards")); + + let mut shard_infos: HashMap = HashMap::new(); + + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc + } else { + PageserverProtocol::Libpq + }; + for shard in shards.iter() { let ps_conf = env .get_pageserver_conf(shard.node_id) @@ -528,19 +539,31 @@ impl ApiMethod for ComputeHookTenant { None }; let pageserver = PageserverShardConnectionInfo { + id: Some(shard.node_id.to_string()), libpq_url, grpc_url, }; - shard_conninfos.insert(shard.shard_number.0 as u32, pageserver); + let shard_info = PageserverShardInfo { + pageservers: vec![pageserver], + }; + shard_infos.insert( + ShardIndex { + shard_number: shard.shard_number, + shard_count, + }, + shard_info, + ); } let pageserver_conninfo = PageserverConnectionInfo { - shards: shard_conninfos, - prefer_grpc: endpoint.grpc, + shard_count: ShardCount::unsharded(), + stripe_size: stripe_size.map(|val| val.0), + shards: shard_infos, + prefer_protocol, }; endpoint - .reconfigure_pageservers(pageserver_conninfo, *stripe_size) + .reconfigure_pageservers(&pageserver_conninfo) .await .map_err(NotifyError::NeonLocal)?; } @@ -824,6 +847,7 @@ impl ComputeHook { let send_locked = tokio::select! { guard = send_lock.lock_owned() => {guard}, _ = cancel.cancelled() => { + tracing::info!("Notification cancelled while waiting for lock"); return Err(NotifyError::ShuttingDown) } }; @@ -865,11 +889,32 @@ impl ComputeHook { let notify_url = compute_hook_url.as_ref().unwrap(); self.do_notify(notify_url, &request, cancel).await } else { - self.do_notify_local::(&request).await.map_err(|e| { + match self.do_notify_local::(&request).await.map_err(|e| { // This path is for testing only, so munge the error into our prod-style error type. - tracing::error!("neon_local notification hook failed: {e}"); - NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR) - }) + if e.to_string().contains("refresh-configuration-pending") { + // If the error message mentions "refresh-configuration-pending", it means the compute node + // rejected our notification request because it already trying to reconfigure itself. We + // can proceed with the rest of the reconcliation process as the compute node already + // discovers the need to reconfigure and will eventually update its configuration once + // we update the pageserver mappings. In fact, it is important that we continue with + // reconcliation to make sure we update the pageserver mappings to unblock the compute node. + tracing::info!("neon_local notification hook failed: {e}"); + tracing::info!("Notification failed likely due to compute node self-reconfiguration, will retry."); + Ok(()) + } else { + tracing::error!("neon_local notification hook failed: {e}"); + Err(NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR)) + } + }) { + // Compute node accepted the notification request. Ok to proceed. + Ok(_) => Ok(()), + // Compute node rejected our request but it is already self-reconfiguring. Ok to proceed. + Err(Ok(_)) => Ok(()), + // Fail the reconciliation attempt in all other cases. Recall that this whole code path involving + // neon_local is for testing only. In production we always retry failed reconcliations so we + // don't have any deadends here. + Err(Err(e)) => Err(e), + } }; match result { diff --git a/storage_controller/src/hadron_utils.rs b/storage_controller/src/hadron_utils.rs new file mode 100644 index 0000000000..871e21c367 --- /dev/null +++ b/storage_controller/src/hadron_utils.rs @@ -0,0 +1,44 @@ +use std::collections::BTreeMap; + +use rand::Rng; +use utils::shard::TenantShardId; + +static CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*()"; + +/// Generate a random string of `length` that can be used as a password. The generated string +/// contains alphanumeric characters and special characters (!@#$%^&*()) +pub fn generate_random_password(length: usize) -> String { + let mut rng = rand::thread_rng(); + (0..length) + .map(|_| { + let idx = rng.gen_range(0..CHARSET.len()); + CHARSET[idx] as char + }) + .collect() +} + +pub(crate) struct TenantShardSizeMap { + #[expect(dead_code)] + pub map: BTreeMap, +} + +impl TenantShardSizeMap { + pub fn new(map: BTreeMap) -> Self { + Self { map } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_generate_random_password() { + let pwd1 = generate_random_password(10); + assert_eq!(pwd1.len(), 10); + let pwd2 = generate_random_password(10); + assert_ne!(pwd1, pwd2); + assert!(pwd1.chars().all(|c| CHARSET.contains(&(c as u8)))); + assert!(pwd2.chars().all(|c| CHARSET.contains(&(c as u8)))); + } +} diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index c8227f0219..ff73719adb 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -48,7 +48,10 @@ use crate::metrics::{ }; use crate::persistence::SafekeeperUpsert; use crate::reconciler::ReconcileError; -use crate::service::{LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service}; +use crate::service::{ + LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service, + TenantMutationLocations, +}; /// State available to HTTP request handlers pub struct HttpState { @@ -734,77 +737,104 @@ async fn handle_tenant_timeline_passthrough( path ); - // Find the node that holds shard zero - let (node, tenant_shard_id, consistent) = if tenant_or_shard_id.is_unsharded() { - service + let tenant_shard_id = if tenant_or_shard_id.is_unsharded() { + // If the request contains only tenant ID, find the node that holds shard zero + let (_, shard_id) = service .tenant_shard0_node(tenant_or_shard_id.tenant_id) - .await? + .await?; + shard_id } else { - let (node, consistent) = service.tenant_shard_node(tenant_or_shard_id).await?; - (node, tenant_or_shard_id, consistent) + tenant_or_shard_id }; - // Callers will always pass an unsharded tenant ID. Before proxying, we must - // rewrite this to a shard-aware shard zero ID. - let path = format!("{path}"); - let tenant_str = tenant_or_shard_id.tenant_id.to_string(); - let tenant_shard_str = format!("{tenant_shard_id}"); - let path = path.replace(&tenant_str, &tenant_shard_str); + let service_inner = service.clone(); - let latency = &METRICS_REGISTRY - .metrics_group - .storage_controller_passthrough_request_latency; + service.tenant_shard_remote_mutation(tenant_shard_id, |locations| async move { + let TenantMutationLocations(locations) = locations; + if locations.is_empty() { + return Err(ApiError::NotFound(anyhow::anyhow!("Tenant {} not found", tenant_or_shard_id.tenant_id).into())); + } - let path_label = path_without_ids(&path) - .split('/') - .filter(|token| !token.is_empty()) - .collect::>() - .join("_"); - let labels = PageserverRequestLabelGroup { - pageserver_id: &node.get_id().to_string(), - path: &path_label, - method: crate::metrics::Method::Get, - }; + let (tenant_or_shard_id, locations) = locations.into_iter().next().unwrap(); + let node = locations.latest.node; - let _timer = latency.start_timer(labels.clone()); + // Callers will always pass an unsharded tenant ID. Before proxying, we must + // rewrite this to a shard-aware shard zero ID. + let path = format!("{path}"); + let tenant_str = tenant_or_shard_id.tenant_id.to_string(); + let tenant_shard_str = format!("{tenant_shard_id}"); + let path = path.replace(&tenant_str, &tenant_shard_str); - let client = mgmt_api::Client::new( - service.get_http_client().clone(), - node.base_url(), - service.get_config().pageserver_jwt_token.as_deref(), - ); - let resp = client.op_raw(method, path).await.map_err(|e| - // We return 503 here because if we can't successfully send a request to the pageserver, - // either we aren't available or the pageserver is unavailable. - ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?; - - if !resp.status().is_success() { - let error_counter = &METRICS_REGISTRY + let latency = &METRICS_REGISTRY .metrics_group - .storage_controller_passthrough_request_error; - error_counter.inc(labels); - } + .storage_controller_passthrough_request_latency; - // Transform 404 into 503 if we raced with a migration - if resp.status() == reqwest::StatusCode::NOT_FOUND && !consistent { - // Rather than retry here, send the client a 503 to prompt a retry: this matches - // the pageserver's use of 503, and all clients calling this API should retry on 503. - return Err(ApiError::ResourceUnavailable( - format!("Pageserver {node} returned 404 due to ongoing migration, retry later").into(), - )); - } + let path_label = path_without_ids(&path) + .split('/') + .filter(|token| !token.is_empty()) + .collect::>() + .join("_"); + let labels = PageserverRequestLabelGroup { + pageserver_id: &node.get_id().to_string(), + path: &path_label, + method: crate::metrics::Method::Get, + }; - // We have a reqest::Response, would like a http::Response - let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp.status())?); - for (k, v) in resp.headers() { - builder = builder.header(k.as_str(), v.as_bytes()); - } + let _timer = latency.start_timer(labels.clone()); - let response = builder - .body(Body::wrap_stream(resp.bytes_stream())) - .map_err(|e| ApiError::InternalServerError(e.into()))?; + let client = mgmt_api::Client::new( + service_inner.get_http_client().clone(), + node.base_url(), + service_inner.get_config().pageserver_jwt_token.as_deref(), + ); + let resp = client.op_raw(method, path).await.map_err(|e| + // We return 503 here because if we can't successfully send a request to the pageserver, + // either we aren't available or the pageserver is unavailable. + ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?; - Ok(response) + if !resp.status().is_success() { + let error_counter = &METRICS_REGISTRY + .metrics_group + .storage_controller_passthrough_request_error; + error_counter.inc(labels); + } + let resp_staus = resp.status(); + + // We have a reqest::Response, would like a http::Response + let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp_staus)?); + for (k, v) in resp.headers() { + builder = builder.header(k.as_str(), v.as_bytes()); + } + let resp_bytes = resp + .bytes() + .await + .map_err(|e| ApiError::InternalServerError(e.into()))?; + // Inspect 404 errors: at this point, we know that the tenant exists, but the pageserver we route + // the request to might not yet be ready. Therefore, if it is a _tenant_ not found error, we can + // convert it into a 503. TODO: we should make this part of the check in `tenant_shard_remote_mutation`. + // However, `tenant_shard_remote_mutation` currently cannot inspect the HTTP error response body, + // so we have to do it here instead. + if resp_staus == reqwest::StatusCode::NOT_FOUND { + let resp_str = std::str::from_utf8(&resp_bytes) + .map_err(|e| ApiError::InternalServerError(e.into()))?; + // We only handle "tenant not found" errors; other 404s like timeline not found should + // be forwarded as-is. + if Service::is_tenant_not_found_error(resp_str, tenant_or_shard_id.tenant_id) { + // Rather than retry here, send the client a 503 to prompt a retry: this matches + // the pageserver's use of 503, and all clients calling this API should retry on 503. + return Err(ApiError::ResourceUnavailable( + format!( + "Pageserver {node} returned tenant 404 due to ongoing migration, retry later" + ) + .into(), + )); + } + } + let response = builder + .body(Body::from(resp_bytes)) + .map_err(|e| ApiError::InternalServerError(e.into()))?; + Ok(response) + }).await? } async fn handle_tenant_locate( @@ -1085,9 +1115,10 @@ async fn handle_node_delete(req: Request) -> Result, ApiErr let state = get_state(&req); let node_id: NodeId = parse_request_param(&req, "node_id")?; + let force: bool = parse_query_param(&req, "force")?.unwrap_or(false); json_response( StatusCode::OK, - state.service.start_node_delete(node_id).await?, + state.service.start_node_delete(node_id, force).await?, ) } diff --git a/storage_controller/src/lib.rs b/storage_controller/src/lib.rs index 36e3c5dc6c..24b06da83a 100644 --- a/storage_controller/src/lib.rs +++ b/storage_controller/src/lib.rs @@ -6,6 +6,7 @@ extern crate hyper0 as hyper; mod auth; mod background_node_operations; mod compute_hook; +pub mod hadron_utils; mod heartbeater; pub mod http; mod id_lock_map; diff --git a/storage_controller/src/metrics.rs b/storage_controller/src/metrics.rs index 8738386968..9c34b34044 100644 --- a/storage_controller/src/metrics.rs +++ b/storage_controller/src/metrics.rs @@ -76,8 +76,8 @@ pub(crate) struct StorageControllerMetricGroup { /// How many shards would like to reconcile but were blocked by concurrency limits pub(crate) storage_controller_pending_reconciles: measured::Gauge, - /// How many shards are keep-failing and will be ignored when considering to run optimizations - pub(crate) storage_controller_keep_failing_reconciles: measured::Gauge, + /// How many shards are stuck and will be ignored when considering to run optimizations + pub(crate) storage_controller_stuck_reconciles: measured::Gauge, /// HTTP request status counters for handled requests pub(crate) storage_controller_http_request_status: @@ -151,6 +151,29 @@ pub(crate) struct StorageControllerMetricGroup { /// Indicator of completed safekeeper reconciles, broken down by safekeeper. pub(crate) storage_controller_safekeeper_reconciles_complete: measured::CounterVec, + + /* BEGIN HADRON */ + /// Hadron `config_watcher` reconciliation runs completed, broken down by success/failure. + pub(crate) storage_controller_config_watcher_complete: + measured::CounterVec, + + /// Hadron long waits for node state changes during drain and fill. + pub(crate) storage_controller_drain_and_fill_long_waits: measured::Counter, + + /// Set to 1 if we detect any page server pods with pending node pool rotation annotations. + /// Requires manual reset after oncall investigation. + pub(crate) storage_controller_ps_node_pool_rotation_pending: measured::Gauge, + + /// Hadron storage scrubber status. + pub(crate) storage_controller_storage_scrub_status: + measured::CounterVec, + + /// Desired number of pageservers managed by the storage controller + pub(crate) storage_controller_num_pageservers_desired: measured::Gauge, + + /// Desired number of safekeepers managed by the storage controller + pub(crate) storage_controller_num_safekeeper_desired: measured::Gauge, + /* END HADRON */ } impl StorageControllerMetrics { @@ -173,6 +196,10 @@ impl Default for StorageControllerMetrics { .storage_controller_reconcile_complete .init_all_dense(); + metrics_group + .storage_controller_config_watcher_complete + .init_all_dense(); + Self { metrics_group, encoder: Mutex::new(measured::text::BufferedTextEncoder::new()), @@ -262,11 +289,48 @@ pub(crate) struct ReconcileLongRunningLabelGroup<'a> { pub(crate) sequence: &'a str, } +#[derive(measured::LabelGroup, Clone)] +#[label(set = StorageScrubberLabelGroupSet)] +pub(crate) struct StorageScrubberLabelGroup<'a> { + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) tenant_id: &'a str, + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) shard_number: &'a str, + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) timeline_id: &'a str, + pub(crate) outcome: StorageScrubberOutcome, +} + +#[derive(FixedCardinalityLabel, Clone, Copy)] +pub(crate) enum StorageScrubberOutcome { + PSOk, + PSWarning, + PSError, + PSOrphan, + SKOk, + SKError, +} + +#[derive(measured::LabelGroup)] +#[label(set = ConfigWatcherCompleteLabelGroupSet)] +pub(crate) struct ConfigWatcherCompleteLabelGroup { + // Reuse the ReconcileOutcome from the SC's reconciliation metrics. + pub(crate) status: ReconcileOutcome, +} + #[derive(FixedCardinalityLabel, Clone, Copy)] pub(crate) enum ReconcileOutcome { + // Successfully reconciled everything. #[label(rename = "ok")] Success, + // Used by tenant-shard reconciler only. Reconciled pageserver state successfully, + // but failed to delivery the compute notificiation. This error is typically transient + // but if its occurance keeps increasing, it should be investigated. + #[label(rename = "ok_no_notify")] + SuccessNoNotify, + // We failed to reconcile some state and the reconcilation will be retried. Error, + // Reconciliation was cancelled. Cancel, } diff --git a/storage_controller/src/node.rs b/storage_controller/src/node.rs index 6642c72f3c..63c82b5682 100644 --- a/storage_controller/src/node.rs +++ b/storage_controller/src/node.rs @@ -51,6 +51,39 @@ pub(crate) struct Node { cancel: CancellationToken, } +#[allow(dead_code)] +const ONE_MILLION: i64 = 1000000; + +// Converts a pool ID to a large number that can be used to assign unique IDs to pods in StatefulSets. +/// For example, if pool_id is 1, then the pods have NodeIds 1000000, 1000001, 1000002, etc. +/// If pool_id is None, then the pods have NodeIds 0, 1, 2, etc. +#[allow(dead_code)] +pub fn transform_pool_id(pool_id: Option) -> i64 { + match pool_id { + Some(id) => (id as i64) * ONE_MILLION, + None => 0, + } +} + +#[allow(dead_code)] +pub fn get_pool_id_from_node_id(node_id: i64) -> i32 { + (node_id / ONE_MILLION) as i32 +} + +/// Example pod name: page-server-0-1, safe-keeper-1-0 +#[allow(dead_code)] +pub fn get_node_id_from_pod_name(pod_name: &str) -> anyhow::Result { + let parts: Vec<&str> = pod_name.split('-').collect(); + if parts.len() != 4 { + return Err(anyhow::anyhow!("Invalid pod name: {}", pod_name)); + } + let pool_id = parts[2].parse::()?; + let node_offset = parts[3].parse::()?; + let node_id = transform_pool_id(Some(pool_id)) + node_offset; + + Ok(NodeId(node_id as u64)) +} + /// When updating [`Node::availability`] we use this type to indicate to the caller /// whether/how they changed it. pub(crate) enum AvailabilityTransition { @@ -403,3 +436,25 @@ impl std::fmt::Debug for Node { write!(f, "{} ({})", self.id, self.listen_http_addr) } } + +#[cfg(test)] +mod tests { + use utils::id::NodeId; + + use crate::node::get_node_id_from_pod_name; + + #[test] + fn test_get_node_id_from_pod_name() { + let pod_name = "page-server-3-12"; + let node_id = get_node_id_from_pod_name(pod_name).unwrap(); + assert_eq!(node_id, NodeId(3000012)); + + let pod_name = "safe-keeper-1-0"; + let node_id = get_node_id_from_pod_name(pod_name).unwrap(); + assert_eq!(node_id, NodeId(1000000)); + + let pod_name = "invalid-pod-name"; + let result = get_node_id_from_pod_name(pod_name); + assert!(result.is_err()); + } +} diff --git a/storage_controller/src/pageserver_client.rs b/storage_controller/src/pageserver_client.rs index da0687895a..9e829e252d 100644 --- a/storage_controller/src/pageserver_client.rs +++ b/storage_controller/src/pageserver_client.rs @@ -14,6 +14,8 @@ use reqwest::StatusCode; use utils::id::{NodeId, TenantId, TimelineId}; use utils::lsn::Lsn; +use crate::hadron_utils::TenantShardSizeMap; + /// Thin wrapper around [`pageserver_client::mgmt_api::Client`]. It allows the storage /// controller to collect metrics in a non-intrusive manner. #[derive(Debug, Clone)] @@ -86,6 +88,31 @@ impl PageserverClient { ) } + #[expect(dead_code)] + pub(crate) async fn tenant_timeline_compact( + &self, + tenant_shard_id: TenantShardId, + timeline_id: TimelineId, + force_image_layer_creation: bool, + wait_until_done: bool, + ) -> Result<()> { + measured_request!( + "tenant_timeline_compact", + crate::metrics::Method::Put, + &self.node_id_label, + self.inner + .tenant_timeline_compact( + tenant_shard_id, + timeline_id, + force_image_layer_creation, + true, + false, + wait_until_done, + ) + .await + ) + } + /* BEGIN_HADRON */ pub(crate) async fn tenant_timeline_describe( &self, @@ -101,6 +128,17 @@ impl PageserverClient { .await ) } + + #[expect(dead_code)] + pub(crate) async fn list_tenant_visible_size(&self) -> Result { + measured_request!( + "list_tenant_visible_size", + crate::metrics::Method::Get, + &self.node_id_label, + self.inner.list_tenant_visible_size().await + ) + .map(TenantShardSizeMap::new) + } /* END_HADRON */ pub(crate) async fn tenant_scan_remote_storage( @@ -365,6 +403,16 @@ impl PageserverClient { ) } + #[expect(dead_code)] + pub(crate) async fn reset_alert_gauges(&self) -> Result<()> { + measured_request!( + "reset_alert_gauges", + crate::metrics::Method::Post, + &self.node_id_label, + self.inner.reset_alert_gauges().await + ) + } + pub(crate) async fn wait_lsn( &self, tenant_shard_id: TenantShardId, diff --git a/storage_controller/src/reconciler.rs b/storage_controller/src/reconciler.rs index a2fba0fa56..d1590ec75e 100644 --- a/storage_controller/src/reconciler.rs +++ b/storage_controller/src/reconciler.rs @@ -862,11 +862,11 @@ impl Reconciler { Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => { if refreshed { tracing::info!( - node_id=%node.get_id(), "Observed configuration correct after refresh. Notifying compute."); + node_id=%node.get_id(), "[Attached] Observed configuration correct after refresh. Notifying compute."); self.compute_notify().await?; } else { // Nothing to do - tracing::info!(node_id=%node.get_id(), "Observed configuration already correct."); + tracing::info!(node_id=%node.get_id(), "[Attached] Observed configuration already correct."); } } observed => { @@ -945,17 +945,17 @@ impl Reconciler { match self.observed.locations.get(&node.get_id()) { Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => { // Nothing to do - tracing::info!(node_id=%node.get_id(), "Observed configuration already correct.") + tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration already correct.") } _ => { // Only try and configure secondary locations on nodes that are available. This // allows the reconciler to "succeed" while some secondaries are offline (e.g. after // a node failure, where the failed node will have a secondary intent) if node.is_available() { - tracing::info!(node_id=%node.get_id(), "Observed configuration requires update."); + tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration requires update."); changes.push((node.clone(), wanted_conf)) } else { - tracing::info!(node_id=%node.get_id(), "Skipping configuration as secondary, node is unavailable"); + tracing::info!(node_id=%node.get_id(), "[Secondary] Skipping configuration as secondary, node is unavailable"); self.observed .locations .insert(node.get_id(), ObservedStateLocation { conf: None }); @@ -1066,6 +1066,9 @@ impl Reconciler { } result } else { + tracing::info!( + "Compute notification is skipped because the tenant shard does not have an attached (primary) location" + ); Ok(()) } } diff --git a/storage_controller/src/schema.rs b/storage_controller/src/schema.rs index 312f7e0b0e..f3dcdaf798 100644 --- a/storage_controller/src/schema.rs +++ b/storage_controller/src/schema.rs @@ -13,6 +13,24 @@ diesel::table! { } } +diesel::table! { + hadron_safekeepers (sk_node_id) { + sk_node_id -> Int8, + listen_http_addr -> Varchar, + listen_http_port -> Int4, + listen_pg_addr -> Varchar, + listen_pg_port -> Int4, + } +} + +diesel::table! { + hadron_timeline_safekeepers (timeline_id, sk_node_id) { + timeline_id -> Varchar, + sk_node_id -> Int8, + legacy_endpoint_id -> Nullable, + } +} + diesel::table! { metadata_health (tenant_id, shard_number, shard_count) { tenant_id -> Varchar, @@ -105,6 +123,8 @@ diesel::table! { diesel::allow_tables_to_appear_in_same_query!( controllers, + hadron_safekeepers, + hadron_timeline_safekeepers, metadata_health, nodes, safekeeper_timeline_pending_ops, diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 0c5d7f44d4..71186076ec 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -207,34 +207,13 @@ enum ShardGenerationValidity { }, } -/// We collect the state of attachments for some operations to determine if the operation -/// needs to be retried when it fails. -struct TenantShardAttachState { - /// The targets of the operation. - /// - /// Tenant shard ID, node ID, node, is intent node observed primary. - targets: Vec<(TenantShardId, NodeId, Node, bool)>, - - /// The targets grouped by node ID. - by_node_id: HashMap, -} - -impl TenantShardAttachState { - fn for_api_call(&self) -> Vec<(TenantShardId, Node)> { - self.targets - .iter() - .map(|(tenant_shard_id, _, node, _)| (*tenant_shard_id, node.clone())) - .collect() - } -} - pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128; pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256; pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32; -// Number of consecutive reconciliation errors, occured for one shard, +// Number of consecutive reconciliations that have occurred for one shard, // after which the shard is ignored when considering to run optimizations. -const MAX_CONSECUTIVE_RECONCILIATION_ERRORS: usize = 5; +const MAX_CONSECUTIVE_RECONCILES: usize = 10; // Depth of the channel used to enqueue shards for reconciliation when they can't do it immediately. // This channel is finite-size to avoid using excessive memory if we get into a state where reconciles are finishing more slowly @@ -719,47 +698,70 @@ pub(crate) enum ReconcileResultRequest { } #[derive(Clone)] -struct MutationLocation { - node: Node, - generation: Generation, +pub(crate) struct MutationLocation { + pub(crate) node: Node, + pub(crate) generation: Generation, } #[derive(Clone)] -struct ShardMutationLocations { - latest: MutationLocation, - other: Vec, +pub(crate) struct ShardMutationLocations { + pub(crate) latest: MutationLocation, + pub(crate) other: Vec, } #[derive(Default, Clone)] -struct TenantMutationLocations(BTreeMap); +pub(crate) struct TenantMutationLocations(pub BTreeMap); struct ReconcileAllResult { spawned_reconciles: usize, - keep_failing_reconciles: usize, + stuck_reconciles: usize, has_delayed_reconciles: bool, } impl ReconcileAllResult { fn new( spawned_reconciles: usize, - keep_failing_reconciles: usize, + stuck_reconciles: usize, has_delayed_reconciles: bool, ) -> Self { assert!( - spawned_reconciles >= keep_failing_reconciles, - "It is impossible to have more keep-failing reconciles than spawned reconciles" + spawned_reconciles >= stuck_reconciles, + "It is impossible to have less spawned reconciles than stuck reconciles" ); Self { spawned_reconciles, - keep_failing_reconciles, + stuck_reconciles, has_delayed_reconciles, } } /// We can run optimizations only if we don't have any delayed reconciles and - /// all spawned reconciles are also keep-failing reconciles. + /// all spawned reconciles are also stuck reconciles. fn can_run_optimizations(&self) -> bool { - !self.has_delayed_reconciles && self.spawned_reconciles == self.keep_failing_reconciles + !self.has_delayed_reconciles && self.spawned_reconciles == self.stuck_reconciles + } +} + +enum TenantIdOrShardId { + TenantId(TenantId), + TenantShardId(TenantShardId), +} + +impl TenantIdOrShardId { + fn tenant_id(&self) -> TenantId { + match self { + TenantIdOrShardId::TenantId(tenant_id) => *tenant_id, + TenantIdOrShardId::TenantShardId(tenant_shard_id) => tenant_shard_id.tenant_id, + } + } + + fn matches(&self, tenant_shard_id: &TenantShardId) -> bool { + match self { + TenantIdOrShardId::TenantId(tenant_id) => tenant_shard_id.tenant_id == *tenant_id, + TenantIdOrShardId::TenantShardId(this_tenant_shard_id) => { + this_tenant_shard_id == tenant_shard_id + } + } } } @@ -1503,7 +1505,6 @@ impl Service { match result.result { Ok(()) => { - tenant.consecutive_errors_count = 0; tenant.apply_observed_deltas(deltas); tenant.waiter.advance(result.sequence); } @@ -1522,8 +1523,6 @@ impl Service { } } - tenant.consecutive_errors_count = tenant.consecutive_errors_count.saturating_add(1); - // Ordering: populate last_error before advancing error_seq, // so that waiters will see the correct error after waiting. tenant.set_last_error(result.sequence, e); @@ -1535,6 +1534,8 @@ impl Service { } } + tenant.consecutive_reconciles_count = tenant.consecutive_reconciles_count.saturating_add(1); + // If we just finished detaching all shards for a tenant, it might be time to drop it from memory. if tenant.policy == PlacementPolicy::Detached { // We may only drop a tenant from memory while holding the exclusive lock on the tenant ID: this protects us @@ -4773,72 +4774,24 @@ impl Service { Ok(()) } - fn is_observed_consistent_with_intent( - &self, - shard: &TenantShard, - intent_node_id: NodeId, - ) -> bool { - if let Some(location) = shard.observed.locations.get(&intent_node_id) - && let Some(ref conf) = location.conf - && (conf.mode == LocationConfigMode::AttachedSingle - || conf.mode == LocationConfigMode::AttachedMulti) - { - true - } else { - false - } - } - - fn collect_tenant_shards( - &self, - tenant_id: TenantId, - ) -> Result { - let locked = self.inner.read().unwrap(); - let mut targets = Vec::new(); - let mut by_node_id = HashMap::new(); - - // If the request got an unsharded tenant id, then apply - // the operation to all shards. Otherwise, apply it to a specific shard. - let shards_range = TenantShardId::tenant_range(tenant_id); - - for (tenant_shard_id, shard) in locked.tenants.range(shards_range) { - if let Some(node_id) = shard.intent.get_attached() { - let node = locked - .nodes - .get(node_id) - .expect("Pageservers may not be deleted while referenced"); - - let consistent = self.is_observed_consistent_with_intent(shard, *node_id); - - targets.push((*tenant_shard_id, *node_id, node.clone(), consistent)); - by_node_id.insert(*node_id, (*tenant_shard_id, node.clone(), consistent)); - } - } - - Ok(TenantShardAttachState { - targets, - by_node_id, - }) + pub(crate) fn is_tenant_not_found_error(body: &str, tenant_id: TenantId) -> bool { + body.contains(&format!("tenant {tenant_id}")) } fn process_result_and_passthrough_errors( &self, + tenant_id: TenantId, results: Vec<(Node, Result)>, - attach_state: TenantShardAttachState, ) -> Result, ApiError> { let mut processed_results: Vec<(Node, T)> = Vec::with_capacity(results.len()); - debug_assert_eq!(results.len(), attach_state.targets.len()); for (node, res) in results { - let is_consistent = attach_state - .by_node_id - .get(&node.get_id()) - .map(|(_, _, consistent)| *consistent); match res { Ok(res) => processed_results.push((node, res)), - Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _)) - if is_consistent == Some(false) => + Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, body)) + if Self::is_tenant_not_found_error(&body, tenant_id) => { - // This is expected if the attach is not finished yet. Return 503 so that the client can retry. + // If there's a tenant not found, we are still in the process of attaching the tenant. + // Return 503 so that the client can retry. return Err(ApiError::ResourceUnavailable( format!( "Timeline is not attached to the pageserver {} yet, please retry", @@ -4866,35 +4819,48 @@ impl Service { ) .await; - let attach_state = self.collect_tenant_shards(tenant_id)?; - - let results = self - .tenant_for_shards_api( - attach_state.for_api_call(), - |tenant_shard_id, client| async move { - client - .timeline_lease_lsn(tenant_shard_id, timeline_id, lsn) - .await - }, - 1, - 1, - SHORT_RECONCILE_TIMEOUT, - &self.cancel, - ) - .await; - - let leases = self.process_result_and_passthrough_errors(results, attach_state)?; - let mut valid_until = None; - for (_, lease) in leases { - if let Some(ref mut valid_until) = valid_until { - *valid_until = std::cmp::min(*valid_until, lease.valid_until); - } else { - valid_until = Some(lease.valid_until); + self.tenant_remote_mutation(tenant_id, |locations| async move { + if locations.0.is_empty() { + return Err(ApiError::NotFound( + anyhow::anyhow!("Tenant not found").into(), + )); } - } - Ok(LsnLease { - valid_until: valid_until.unwrap_or_else(SystemTime::now), + + let results = self + .tenant_for_shards_api( + locations + .0 + .iter() + .map(|(tenant_shard_id, ShardMutationLocations { latest, .. })| { + (*tenant_shard_id, latest.node.clone()) + }) + .collect(), + |tenant_shard_id, client| async move { + client + .timeline_lease_lsn(tenant_shard_id, timeline_id, lsn) + .await + }, + 1, + 1, + SHORT_RECONCILE_TIMEOUT, + &self.cancel, + ) + .await; + + let leases = self.process_result_and_passthrough_errors(tenant_id, results)?; + let mut valid_until = None; + for (_, lease) in leases { + if let Some(ref mut valid_until) = valid_until { + *valid_until = std::cmp::min(*valid_until, lease.valid_until); + } else { + valid_until = Some(lease.valid_until); + } + } + Ok(LsnLease { + valid_until: valid_until.unwrap_or_else(SystemTime::now), + }) }) + .await? } pub(crate) async fn tenant_timeline_download_heatmap_layers( @@ -5041,11 +5007,37 @@ impl Service { /// - Looks up the shards and the nodes where they were most recently attached /// - Guarantees that after the inner function returns, the shards' generations haven't moved on: this /// ensures that the remote operation acted on the most recent generation, and is therefore durable. - async fn tenant_remote_mutation( + pub(crate) async fn tenant_remote_mutation( &self, tenant_id: TenantId, op: O, ) -> Result + where + O: FnOnce(TenantMutationLocations) -> F, + F: std::future::Future, + { + self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantId(tenant_id), op) + .await + } + + pub(crate) async fn tenant_shard_remote_mutation( + &self, + tenant_shard_id: TenantShardId, + op: O, + ) -> Result + where + O: FnOnce(TenantMutationLocations) -> F, + F: std::future::Future, + { + self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantShardId(tenant_shard_id), op) + .await + } + + async fn tenant_remote_mutation_inner( + &self, + tenant_id_or_shard_id: TenantIdOrShardId, + op: O, + ) -> Result where O: FnOnce(TenantMutationLocations) -> F, F: std::future::Future, @@ -5057,7 +5049,13 @@ impl Service { // run concurrently with reconciliations, and it is not guaranteed that the node we find here // will still be the latest when we're done: we will check generations again at the end of // this function to handle that. - let generations = self.persistence.tenant_generations(tenant_id).await?; + let generations = self + .persistence + .tenant_generations(tenant_id_or_shard_id.tenant_id()) + .await? + .into_iter() + .filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id)) + .collect::>(); if generations .iter() @@ -5071,9 +5069,14 @@ impl Service { // One or more shards has not been attached to a pageserver. Check if this is because it's configured // to be detached (409: caller should give up), or because it's meant to be attached but isn't yet (503: caller should retry) let locked = self.inner.read().unwrap(); - for (shard_id, shard) in - locked.tenants.range(TenantShardId::tenant_range(tenant_id)) - { + let tenant_shards = locked + .tenants + .range(TenantShardId::tenant_range( + tenant_id_or_shard_id.tenant_id(), + )) + .filter(|(shard_id, _)| tenant_id_or_shard_id.matches(shard_id)) + .collect::>(); + for (shard_id, shard) in tenant_shards { match shard.policy { PlacementPolicy::Attached(_) => { // This shard is meant to be attached: the caller is not wrong to try and @@ -5183,7 +5186,14 @@ impl Service { // Post-check: are all the generations of all the shards the same as they were initially? This proves that // our remote operation executed on the latest generation and is therefore persistent. { - let latest_generations = self.persistence.tenant_generations(tenant_id).await?; + let latest_generations = self + .persistence + .tenant_generations(tenant_id_or_shard_id.tenant_id()) + .await? + .into_iter() + .filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id)) + .collect::>(); + if latest_generations .into_iter() .map( @@ -5317,7 +5327,7 @@ impl Service { pub(crate) async fn tenant_shard0_node( &self, tenant_id: TenantId, - ) -> Result<(Node, TenantShardId, bool), ApiError> { + ) -> Result<(Node, TenantShardId), ApiError> { let tenant_shard_id = { let locked = self.inner.read().unwrap(); let Some((tenant_shard_id, _shard)) = locked @@ -5335,7 +5345,7 @@ impl Service { self.tenant_shard_node(tenant_shard_id) .await - .map(|(node, consistent)| (node, tenant_shard_id, consistent)) + .map(|node| (node, tenant_shard_id)) } /// When you need to send an HTTP request to the pageserver that holds a shard of a tenant, this @@ -5345,7 +5355,7 @@ impl Service { pub(crate) async fn tenant_shard_node( &self, tenant_shard_id: TenantShardId, - ) -> Result<(Node, bool), ApiError> { + ) -> Result { // Look up in-memory state and maybe use the node from there. { let locked = self.inner.read().unwrap(); @@ -5375,8 +5385,7 @@ impl Service { "Shard refers to nonexistent node" ))); }; - let consistent = self.is_observed_consistent_with_intent(shard, *intent_node_id); - return Ok((node.clone(), consistent)); + return Ok(node.clone()); } }; @@ -5411,7 +5420,7 @@ impl Service { ))); }; // As a reconciliation is in flight, we do not have the observed state yet, and therefore we assume it is always inconsistent. - Ok((node.clone(), false)) + Ok(node.clone()) } pub(crate) fn tenant_locate( @@ -7385,6 +7394,7 @@ impl Service { self: &Arc, node_id: NodeId, policy_on_start: NodeSchedulingPolicy, + force: bool, cancel: CancellationToken, ) -> Result<(), OperationError> { let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal).build(); @@ -7392,23 +7402,27 @@ impl Service { let mut waiters: Vec = Vec::new(); let mut tid_iter = create_shared_shard_iterator(self.clone()); + let reset_node_policy_on_cancel = || async { + match self + .node_configure(node_id, None, Some(policy_on_start)) + .await + { + Ok(()) => OperationError::Cancelled, + Err(err) => { + OperationError::FinalizeError( + format!( + "Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}", + node_id, String::from(policy_on_start), err + ) + .into(), + ) + } + } + }; + while !tid_iter.finished() { if cancel.is_cancelled() { - match self - .node_configure(node_id, None, Some(policy_on_start)) - .await - { - Ok(()) => return Err(OperationError::Cancelled), - Err(err) => { - return Err(OperationError::FinalizeError( - format!( - "Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}", - node_id, String::from(policy_on_start), err - ) - .into(), - )); - } - } + return Err(reset_node_policy_on_cancel().await); } operation_utils::validate_node_state( @@ -7477,8 +7491,18 @@ impl Service { nodes, reconciler_config, ); - if let Some(some) = waiter { - waiters.push(some); + + if force { + // Here we remove an existing observed location for the node we're removing, and it will + // not be re-added by a reconciler's completion because we filter out removed nodes in + // process_result. + // + // Note that we update the shard's observed state _after_ calling maybe_configured_reconcile_shard: + // that means any reconciles we spawned will know about the node we're deleting, + // enabling them to do live migrations if it's still online. + tenant_shard.observed.locations.remove(&node_id); + } else if let Some(waiter) = waiter { + waiters.push(waiter); } } } @@ -7492,21 +7516,7 @@ impl Service { while !waiters.is_empty() { if cancel.is_cancelled() { - match self - .node_configure(node_id, None, Some(policy_on_start)) - .await - { - Ok(()) => return Err(OperationError::Cancelled), - Err(err) => { - return Err(OperationError::FinalizeError( - format!( - "Failed to finalise drain cancel of {} by setting scheduling policy to {}: {}", - node_id, String::from(policy_on_start), err - ) - .into(), - )); - } - } + return Err(reset_node_policy_on_cancel().await); } tracing::info!("Awaiting {} pending delete reconciliations", waiters.len()); @@ -7516,6 +7526,12 @@ impl Service { .await; } + let pf = pausable_failpoint!("delete-node-after-reconciles-spawned", &cancel); + if pf.is_err() { + // An error from pausable_failpoint indicates the cancel token was triggered. + return Err(reset_node_policy_on_cancel().await); + } + self.persistence .set_tombstone(node_id) .await @@ -8111,6 +8127,7 @@ impl Service { pub(crate) async fn start_node_delete( self: &Arc, node_id: NodeId, + force: bool, ) -> Result<(), ApiError> { let (ongoing_op, node_policy, schedulable_nodes_count) = { let locked = self.inner.read().unwrap(); @@ -8180,7 +8197,7 @@ impl Service { tracing::info!("Delete background operation starting"); let res = service - .delete_node(node_id, policy_on_start, cancel) + .delete_node(node_id, policy_on_start, force, cancel) .await; match res { Ok(()) => { @@ -8632,7 +8649,7 @@ impl Service { // This function is an efficient place to update lazy statistics, since we are walking // all tenants. let mut pending_reconciles = 0; - let mut keep_failing_reconciles = 0; + let mut stuck_reconciles = 0; let mut az_violations = 0; // If we find any tenants to drop from memory, stash them to offload after @@ -8668,30 +8685,32 @@ impl Service { // Eventual consistency: if an earlier reconcile job failed, and the shard is still // dirty, spawn another one - let consecutive_errors_count = shard.consecutive_errors_count; if self .maybe_reconcile_shard(shard, &pageservers, ReconcilerPriority::Normal) .is_some() { spawned_reconciles += 1; - // Count shards that are keep-failing. We still want to reconcile them - // to avoid a situation where a shard is stuck. - // But we don't want to consider them when deciding to run optimizations. - if consecutive_errors_count >= MAX_CONSECUTIVE_RECONCILIATION_ERRORS { + if shard.consecutive_reconciles_count >= MAX_CONSECUTIVE_RECONCILES { + // Count shards that are stuck, butwe still want to reconcile them. + // We don't want to consider them when deciding to run optimizations. tracing::warn!( tenant_id=%shard.tenant_shard_id.tenant_id, shard_id=%shard.tenant_shard_id.shard_slug(), - "Shard reconciliation is keep-failing: {} errors", - consecutive_errors_count + "Shard reconciliation is stuck: {} consecutive launches", + shard.consecutive_reconciles_count ); - keep_failing_reconciles += 1; + stuck_reconciles += 1; + } + } else { + if shard.delayed_reconcile { + // Shard wanted to reconcile but for some reason couldn't. + pending_reconciles += 1; } - } else if shard.delayed_reconcile { - // Shard wanted to reconcile but for some reason couldn't. - pending_reconciles += 1; - } + // Reset the counter when we don't need to launch a reconcile. + shard.consecutive_reconciles_count = 0; + } // If this tenant is detached, try dropping it from memory. This is usually done // proactively in [`Self::process_results`], but we do it here to handle the edge // case where a reconcile completes while someone else is holding an op lock for the tenant. @@ -8727,14 +8746,10 @@ impl Service { metrics::METRICS_REGISTRY .metrics_group - .storage_controller_keep_failing_reconciles - .set(keep_failing_reconciles as i64); + .storage_controller_stuck_reconciles + .set(stuck_reconciles as i64); - ReconcileAllResult::new( - spawned_reconciles, - keep_failing_reconciles, - has_delayed_reconciles, - ) + ReconcileAllResult::new(spawned_reconciles, stuck_reconciles, has_delayed_reconciles) } /// `optimize` in this context means identifying shards which have valid scheduled locations, but diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index 99079c57b0..f60378470e 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -131,14 +131,16 @@ pub(crate) struct TenantShard { #[serde(serialize_with = "read_last_error")] pub(crate) last_error: std::sync::Arc>>>, - /// Number of consecutive reconciliation errors that have occurred for this shard. + /// Amount of consecutive [`crate::service::Service::reconcile_all`] iterations that have been + /// scheduled a reconciliation for this shard. /// - /// When this count reaches MAX_CONSECUTIVE_RECONCILIATION_ERRORS, the tenant shard - /// will be countered as keep-failing in `reconcile_all` calculations. This will lead to - /// allowing optimizations to run even with some failing shards. + /// If this reaches `MAX_CONSECUTIVE_RECONCILES`, the shard is considered "stuck" and will be + /// ignored when deciding whether optimizations can run. This includes both successful and failed + /// reconciliations. /// - /// The counter is reset to 0 after a successful reconciliation. - pub(crate) consecutive_errors_count: usize, + /// Incremented in [`crate::service::Service::process_result`], and reset to 0 when + /// [`crate::service::Service::reconcile_all`] determines no reconciliation is needed for this shard. + pub(crate) consecutive_reconciles_count: usize, /// If we have a pending compute notification that for some reason we weren't able to send, /// set this to true. If this is set, calls to [`Self::get_reconcile_needed`] will return Yes @@ -603,7 +605,7 @@ impl TenantShard { waiter: Arc::new(SeqWait::new(Sequence(0))), error_waiter: Arc::new(SeqWait::new(Sequence(0))), last_error: Arc::default(), - consecutive_errors_count: 0, + consecutive_reconciles_count: 0, pending_compute_notification: false, scheduling_policy: ShardSchedulingPolicy::default(), preferred_node: None, @@ -1609,7 +1611,13 @@ impl TenantShard { // Update result counter let outcome_label = match &result { - Ok(_) => ReconcileOutcome::Success, + Ok(_) => { + if reconciler.compute_notify_failure { + ReconcileOutcome::SuccessNoNotify + } else { + ReconcileOutcome::Success + } + } Err(ReconcileError::Cancel) => ReconcileOutcome::Cancel, Err(_) => ReconcileOutcome::Error, }; @@ -1908,7 +1916,7 @@ impl TenantShard { waiter: Arc::new(SeqWait::new(Sequence::initial())), error_waiter: Arc::new(SeqWait::new(Sequence::initial())), last_error: Arc::default(), - consecutive_errors_count: 0, + consecutive_reconciles_count: 0, pending_compute_notification: false, delayed_reconcile: false, scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(), diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 7c90c0162c..fc33fb45c1 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2119,11 +2119,14 @@ class NeonStorageController(MetricsGetter, LogUtils): headers=self.headers(TokenScope.ADMIN), ) - def node_delete(self, node_id): + def node_delete(self, node_id, force: bool = False): log.info(f"node_delete({node_id})") + query = f"{self.api}/control/v1/node/{node_id}/delete" + if force: + query += "?force=true" self.request( "PUT", - f"{self.api}/control/v1/node/{node_id}/delete", + query, headers=self.headers(TokenScope.ADMIN), ) diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index 23b9d1c8c9..f95b0ee4d1 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -847,7 +847,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): return res_json def timeline_lsn_lease( - self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn, **kwargs ): data = { "lsn": str(lsn), @@ -857,6 +857,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): res = self.post( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/lsn_lease", json=data, + **kwargs, ) self.verbose_error(res) res_json = res.json() diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index a4d2bf8d9b..a3a20cdc62 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -187,19 +187,21 @@ def test_create_snapshot( env.pageserver.stop() env.storage_controller.stop() - # Directory `compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it - compatibility_snapshot_dir = ( + # Directory `new_compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it + new_compatibility_snapshot_dir = ( top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}" ) - if compatibility_snapshot_dir.exists(): - shutil.rmtree(compatibility_snapshot_dir) + if new_compatibility_snapshot_dir.exists(): + shutil.rmtree(new_compatibility_snapshot_dir) shutil.copytree( test_output_dir, - compatibility_snapshot_dir, + new_compatibility_snapshot_dir, ignore=shutil.ignore_patterns("pg_dynshmem"), ) + log.info(f"Copied new compatibility snapshot dir to: {new_compatibility_snapshot_dir}") + # check_neon_works does recovery from WAL => the compatibility snapshot's WAL is old => will log this warning ingest_lag_log_line = ".*ingesting record with timestamp lagging more than wait_lsn_timeout.*" @@ -218,6 +220,7 @@ def test_backward_compatibility( """ Test that the new binaries can read old data """ + log.info(f"Using snapshot dir at {compatibility_snapshot_dir}") neon_env_builder.num_safekeepers = 3 env = neon_env_builder.from_repo_dir(compatibility_snapshot_dir / "repo") env.pageserver.allowed_errors.append(ingest_lag_log_line) @@ -242,7 +245,6 @@ def test_forward_compatibility( test_output_dir: Path, top_output_dir: Path, pg_version: PgVersion, - compatibility_snapshot_dir: Path, compute_reconfigure_listener: ComputeReconfigure, ): """ @@ -266,8 +268,14 @@ def test_forward_compatibility( neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir + # Note that we are testing with new data, so we should use `new_compatibility_snapshot_dir`, which is created by test_create_snapshot. + new_compatibility_snapshot_dir = ( + top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}" + ) + + log.info(f"Using snapshot dir at {new_compatibility_snapshot_dir}") env = neon_env_builder.from_repo_dir( - compatibility_snapshot_dir / "repo", + new_compatibility_snapshot_dir / "repo", ) # there may be an arbitrary number of unrelated tests run between create_snapshot and here env.pageserver.allowed_errors.append(ingest_lag_log_line) @@ -296,7 +304,7 @@ def test_forward_compatibility( check_neon_works( env, test_output_dir=test_output_dir, - sql_dump_path=compatibility_snapshot_dir / "dump.sql", + sql_dump_path=new_compatibility_snapshot_dir / "dump.sql", repo_dir=env.repo_dir, ) diff --git a/test_runner/regress/test_pageserver_layer_rolling.py b/test_runner/regress/test_pageserver_layer_rolling.py index 91c4ef521c..68f470d962 100644 --- a/test_runner/regress/test_pageserver_layer_rolling.py +++ b/test_runner/regress/test_pageserver_layer_rolling.py @@ -246,9 +246,9 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder): system_memory = psutil.virtual_memory().total - # The smallest total size limit we can configure is 1/1024th of the system memory (e.g. 128MB on - # a system with 128GB of RAM). We will then write enough data to violate this limit. - max_dirty_data = 128 * 1024 * 1024 + # The smallest total size limit we can configure is 1/1024th of the system memory (e.g. 256MB on + # a system with 256GB of RAM). We will then write enough data to violate this limit. + max_dirty_data = 256 * 1024 * 1024 ephemeral_bytes_per_memory_kb = (max_dirty_data * 1024) // system_memory assert ephemeral_bytes_per_memory_kb > 0 @@ -272,7 +272,7 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder): timeline_count = 10 # This is about 2MiB of data per timeline - entries_per_timeline = 100_000 + entries_per_timeline = 200_000 last_flush_lsns = asyncio.run(workload(env, tenant_conf, timeline_count, entries_per_timeline)) wait_until_pageserver_is_caught_up(env, last_flush_lsns) diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 10845ef02e..9986c1f24a 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING import fixtures.utils import pytest from fixtures.auth_tokens import TokenScope -from fixtures.common_types import TenantId, TenantShardId, TimelineId +from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import ( DEFAULT_AZ_ID, @@ -47,6 +47,7 @@ from fixtures.utils import ( wait_until, ) from fixtures.workload import Workload +from requests.adapters import HTTPAdapter from urllib3 import Retry from werkzeug.wrappers.response import Response @@ -72,6 +73,12 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids): return counts +class DeletionAPIKind(Enum): + OLD = "old" + FORCE = "force" + GRACEFUL = "graceful" + + @pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) def test_storage_controller_smoke( neon_env_builder: NeonEnvBuilder, compute_reconfigure_listener: ComputeReconfigure, combination @@ -990,7 +997,7 @@ def test_storage_controller_compute_hook_retry( @run_only_on_default_postgres("postgres behavior is not relevant") -def test_storage_controller_compute_hook_keep_failing( +def test_storage_controller_compute_hook_stuck_reconciles( httpserver: HTTPServer, neon_env_builder: NeonEnvBuilder, httpserver_listen_address: ListenAddress, @@ -1040,7 +1047,7 @@ def test_storage_controller_compute_hook_keep_failing( env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG) env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS) env.storage_controller.allowed_errors.append(".*Keeping extra secondaries.*") - env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*") + env.storage_controller.allowed_errors.append(".*Shard reconciliation is stuck.*") env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"}) # Migrate all allowed tenant shards to the first alive pageserver @@ -1055,7 +1062,7 @@ def test_storage_controller_compute_hook_keep_failing( # Make some reconcile_all calls to trigger optimizations # RECONCILE_COUNT must be greater than storcon's MAX_CONSECUTIVE_RECONCILIATION_ERRORS - RECONCILE_COUNT = 12 + RECONCILE_COUNT = 20 for i in range(RECONCILE_COUNT): try: n = env.storage_controller.reconcile_all() @@ -1068,6 +1075,8 @@ def test_storage_controller_compute_hook_keep_failing( assert banned_descr["shards"][0]["is_pending_compute_notification"] is True time.sleep(2) + env.storage_controller.assert_log_contains(".*Shard reconciliation is stuck.*") + # Check that the allowed tenant shards are optimized due to affinity rules locations = alive_pageservers[0].http_client().tenant_list_locations()["tenant_shards"] not_optimized_shard_count = 0 @@ -2572,9 +2581,11 @@ def test_background_operation_cancellation(neon_env_builder: NeonEnvBuilder): @pytest.mark.parametrize("while_offline", [True, False]) +@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE]) def test_storage_controller_node_deletion( neon_env_builder: NeonEnvBuilder, while_offline: bool, + deletion_api: DeletionAPIKind, ): """ Test that deleting a node works & properly reschedules everything that was on the node. @@ -2598,6 +2609,8 @@ def test_storage_controller_node_deletion( assert env.storage_controller.reconcile_all() == 0 victim = env.pageservers[-1] + if deletion_api == DeletionAPIKind.FORCE and not while_offline: + victim.allowed_errors.append(".*request was dropped before completing.*") # The procedure a human would follow is: # 1. Mark pageserver scheduling=pause @@ -2621,7 +2634,12 @@ def test_storage_controller_node_deletion( wait_until(assert_shards_migrated) log.info(f"Deleting pageserver {victim.id}") - env.storage_controller.node_delete_old(victim.id) + if deletion_api == DeletionAPIKind.FORCE: + env.storage_controller.node_delete(victim.id, force=True) + elif deletion_api == DeletionAPIKind.OLD: + env.storage_controller.node_delete_old(victim.id) + else: + raise AssertionError(f"Invalid deletion API: {deletion_api}") if not while_offline: @@ -2634,7 +2652,15 @@ def test_storage_controller_node_deletion( wait_until(assert_victim_evacuated) # The node should be gone from the list API - assert victim.id not in [n["id"] for n in env.storage_controller.node_list()] + def assert_node_is_gone(): + assert victim.id not in [n["id"] for n in env.storage_controller.node_list()] + + if deletion_api == DeletionAPIKind.FORCE: + wait_until(assert_node_is_gone) + elif deletion_api == DeletionAPIKind.OLD: + assert_node_is_gone() + else: + raise AssertionError(f"Invalid deletion API: {deletion_api}") # No tenants should refer to the node in their intent for tenant_id in tenant_ids: @@ -2656,7 +2682,11 @@ def test_storage_controller_node_deletion( env.storage_controller.consistency_check() -def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL]) +def test_storage_controller_node_delete_cancellation( + neon_env_builder: NeonEnvBuilder, + deletion_api: DeletionAPIKind, +): neon_env_builder.num_pageservers = 3 neon_env_builder.num_azs = 3 env = neon_env_builder.init_configs() @@ -2680,12 +2710,16 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu assert len(nodes) == 3 env.storage_controller.configure_failpoints(("sleepy-delete-loop", "return(10000)")) + env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "pause")) ps_id_to_delete = env.pageservers[0].id env.storage_controller.warm_up_all_secondaries() + + assert deletion_api in [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL] + force = deletion_api == DeletionAPIKind.FORCE env.storage_controller.retryable_node_operation( - lambda ps_id: env.storage_controller.node_delete(ps_id), + lambda ps_id: env.storage_controller.node_delete(ps_id, force), ps_id_to_delete, max_attempts=3, backoff=2, @@ -2701,6 +2735,8 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu env.storage_controller.cancel_node_delete(ps_id_to_delete) + env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "off")) + env.storage_controller.poll_node_status( ps_id_to_delete, PageserverAvailability.ACTIVE, @@ -3252,7 +3288,10 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB wait_until(reconfigure_node_again) -def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE]) +def test_ps_unavailable_after_delete( + neon_env_builder: NeonEnvBuilder, deletion_api: DeletionAPIKind +): neon_env_builder.num_pageservers = 3 env = neon_env_builder.init_start() @@ -3265,10 +3304,16 @@ def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder): assert_nodes_count(3) ps = env.pageservers[0] - env.storage_controller.node_delete_old(ps.id) - # After deletion, the node count must be reduced - assert_nodes_count(2) + if deletion_api == DeletionAPIKind.FORCE: + ps.allowed_errors.append(".*request was dropped before completing.*") + env.storage_controller.node_delete(ps.id, force=True) + wait_until(lambda: assert_nodes_count(2)) + elif deletion_api == DeletionAPIKind.OLD: + env.storage_controller.node_delete_old(ps.id) + assert_nodes_count(2) + else: + raise AssertionError(f"Invalid deletion API: {deletion_api}") # Running pageserver CLI init in a separate thread with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: @@ -4814,3 +4859,103 @@ def test_storage_controller_migrate_with_pageserver_restart( "shards": [{"node_id": int(secondary.id), "shard_number": 0}], "preferred_az": DEFAULT_AZ_ID, } + + +@run_only_on_default_postgres("PG version is not important for this test") +def test_storage_controller_forward_404(neon_env_builder: NeonEnvBuilder): + """ + Ensures that the storage controller correctly forwards 404s and converts some of them + into 503s before forwarding to the client. + """ + neon_env_builder.num_pageservers = 2 + neon_env_builder.num_azs = 2 + + env = neon_env_builder.init_start() + env.storage_controller.allowed_errors.append(".*Reconcile error.*") + env.storage_controller.allowed_errors.append(".*Timed out.*") + + env.storage_controller.tenant_policy_update(env.initial_tenant, {"placement": {"Attached": 1}}) + env.storage_controller.reconcile_until_idle() + + # 404s on tenants and timelines are forwarded as-is when reconciler is not running. + + # Access a non-existing timeline -> 404 + with pytest.raises(PageserverApiException) as e: + env.storage_controller.pageserver_api().timeline_detail( + env.initial_tenant, TimelineId.generate() + ) + assert e.value.status_code == 404 + with pytest.raises(PageserverApiException) as e: + env.storage_controller.pageserver_api().timeline_lsn_lease( + env.initial_tenant, TimelineId.generate(), Lsn(0) + ) + assert e.value.status_code == 404 + + # Access a non-existing tenant when reconciler is not running -> 404 + with pytest.raises(PageserverApiException) as e: + env.storage_controller.pageserver_api().timeline_detail( + TenantId.generate(), env.initial_timeline + ) + assert e.value.status_code == 404 + with pytest.raises(PageserverApiException) as e: + env.storage_controller.pageserver_api().timeline_lsn_lease( + TenantId.generate(), env.initial_timeline, Lsn(0) + ) + assert e.value.status_code == 404 + + # Normal requests should succeed + detail = env.storage_controller.pageserver_api().timeline_detail( + env.initial_tenant, env.initial_timeline + ) + last_record_lsn = Lsn(detail["last_record_lsn"]) + env.storage_controller.pageserver_api().timeline_lsn_lease( + env.initial_tenant, env.initial_timeline, last_record_lsn + ) + + # Get into a situation where the intent state is not the same as the observed state. + describe = env.storage_controller.tenant_describe(env.initial_tenant)["shards"][0] + current_primary = describe["node_attached"] + current_secondary = describe["node_secondary"][0] + assert current_primary != current_secondary + + # Pause the reconciler so that the generation number won't be updated. + env.storage_controller.configure_failpoints( + ("reconciler-live-migrate-post-generation-inc", "pause") + ) + + # Do the migration in another thread; the request will be dropped as we don't wait. + shard_zero = TenantShardId(env.initial_tenant, 0, 0) + concurrent.futures.ThreadPoolExecutor(max_workers=1).submit( + env.storage_controller.tenant_shard_migrate, + shard_zero, + current_secondary, + StorageControllerMigrationConfig(override_scheduler=True), + ) + # Not the best way to do this, we should wait until the migration gets started. + time.sleep(1) + placement = env.storage_controller.get_tenants_placement()[str(shard_zero)] + assert placement["observed"] != placement["intent"] + assert placement["observed"]["attached"] == current_primary + assert placement["intent"]["attached"] == current_secondary + + # Now we issue requests that would cause 404 again + retry_strategy = Retry(total=0) + adapter = HTTPAdapter(max_retries=retry_strategy) + + no_retry_api = env.storage_controller.pageserver_api() + no_retry_api.mount("http://", adapter) + no_retry_api.mount("https://", adapter) + + # As intent state != observed state, tenant not found error should return 503, + # so that the client can retry once we've successfully migrated. + with pytest.raises(PageserverApiException) as e: + no_retry_api.timeline_detail(env.initial_tenant, TimelineId.generate()) + assert e.value.status_code == 503, f"unexpected status code and error: {e.value}" + with pytest.raises(PageserverApiException) as e: + no_retry_api.timeline_lsn_lease(env.initial_tenant, TimelineId.generate(), Lsn(0)) + assert e.value.status_code == 503, f"unexpected status code and error: {e.value}" + + # Unblock reconcile operations + env.storage_controller.configure_failpoints( + ("reconciler-live-migrate-post-generation-inc", "off") + ) diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index c61598cdf6..d6d64a2045 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -107,7 +107,6 @@ tracing-core = { version = "0.1" } tracing-log = { version = "0.2" } tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } url = { version = "2", features = ["serde"] } -uuid = { version = "1", features = ["serde", "v4", "v7"] } zeroize = { version = "1", features = ["derive", "serde"] } zstd = { version = "0.13" } zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }