From ca88521653fcd574ecd1eae49a5d861139d3e4d2 Mon Sep 17 00:00:00 2001
From: HaoyuHuang
Date: Tue, 29 Jul 2025 14:30:34 -0700
Subject: [PATCH 1/6] Set neon_superuser privilege under lakebase mode (#12775)

## Problem

## Summary of changes

---
 compute_tools/src/spec_apply.rs                  | 7 ++++++-
 compute_tools/src/sql/create_privileged_role.sql | 2 +-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/compute_tools/src/spec_apply.rs b/compute_tools/src/spec_apply.rs
index 2356078703..00f34c1f0e 100644
--- a/compute_tools/src/spec_apply.rs
+++ b/compute_tools/src/spec_apply.rs
@@ -679,7 +679,12 @@ async fn get_operations<'a>(
         ApplySpecPhase::CreatePrivilegedRole => Ok(Box::new(once(Operation {
             query: format!(
                 include_str!("sql/create_privileged_role.sql"),
-                privileged_role_name = params.privileged_role_name
+                privileged_role_name = params.privileged_role_name,
+                privileges = if params.lakebase_mode {
+                    "CREATEDB CREATEROLE NOLOGIN BYPASSRLS"
+                } else {
+                    "CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS"
+                }
             ),
             comment: None,
         }))),
diff --git a/compute_tools/src/sql/create_privileged_role.sql b/compute_tools/src/sql/create_privileged_role.sql
index df27ac32fc..ac2521445f 100644
--- a/compute_tools/src/sql/create_privileged_role.sql
+++ b/compute_tools/src/sql/create_privileged_role.sql
@@ -2,7 +2,7 @@ DO $$
     BEGIN
         IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{privileged_role_name}')
         THEN
-            CREATE ROLE {privileged_role_name} CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data;
+            CREATE ROLE {privileged_role_name} {privileges} IN ROLE pg_read_all_data, pg_write_all_data;
         END IF;
     END
 $$;

From 1dce2a9e746edf7b93ce1048ebf63bf5c1395c18 Mon Sep 17 00:00:00 2001
From: Heikki Linnakangas
Date: Wed, 30 Jul 2025 01:20:05 +0300
Subject: [PATCH 2/6] Change how pageserver connection info is passed in
 compute spec (#12604)

Add a new 'pageserver_connection_info' field in the compute spec. It
replaces the old 'pageserver_connstring' field with a more complicated
struct that includes both libpq and grpc URLs for each shard (or only
one of the URLs, depending on the configuration). It also includes a
flag suggesting which one to use; compute_ctl now uses it to decide
which protocol to use for the basebackup.

This is backwards-compatible with everything that's in production. If
the control plane fills in `pageserver_connection_info`, compute_ctl
uses that. If it fills in the `pageserver_connstring`/`shard_stripe_size`
fields, it uses those. As a last resort, it uses the
'neon.pageserver_connstring' GUC from the list of Postgres settings.

The 'grpc' flag in the endpoint config is now more of a suggestion, and
it's used to populate the 'prefer_protocol' flag in the compute spec.
Regardless of the flag, compute_ctl gets both URLs, so it can choose to
use libpq or grpc as it wishes. It currently always obeys the flag to
choose which method to use for getting the basebackup, but Postgres
itself will always use the libpq protocol. (That will be changed with
the new rust-based communicator project, which implements the gRPC
client in the compute.) After that, the
`pageserver_connection_info.prefer_protocol` flag in the spec file can
be used to control whether compute_ctl uses grpc or libpq. The actual
compute's grpc usage will be controlled by the
`neon.enable_new_communicator` GUC (not yet; that will be introduced in
the future, with the new rust-based communicator project). It can be
set separately from 'prefer_protocol'.
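For illustration, here is a minimal sketch (not part of the patch itself) of
how a legacy connstring-based configuration maps into the new struct, using
the `PageserverConnectionInfo::from_connstr` helper introduced below; the
hosts, ports and stripe size are made-up values:

```rust
// Hypothetical two-shard legacy configuration. On this compatibility path
// only libpq URLs exist, so grpc_url is left as None for every shard, and
// prefer_protocol defaults to libpq.
let conninfo = PageserverConnectionInfo::from_connstr(
    "postgres://no_user@ps1:6400,postgres://no_user@ps2:6400",
    Some(ShardStripeSize(32768)),
)?;
assert_eq!(conninfo.shard_count, ShardCount(2));
assert_eq!(
    conninfo.shard_url(ShardNumber(0), PageserverProtocol::Libpq)?,
    "postgres://no_user@ps1:6400"
);
```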
Later:

- Once all old computes are gone, remove the code to pass
  `neon.pageserver_connstring`
---
 compute_tools/src/compute.rs           | 133 ++++++++++--------
 compute_tools/src/config.rs            |  72 +++++++++-
 compute_tools/src/configurator.rs      |   7 +-
 compute_tools/src/lsn_lease.rs         |  61 ++++----
 control_plane/src/bin/neon_local.rs    | 170 ++++++++--------------
 control_plane/src/endpoint.rs          | 187 +++++++++++++++++------
 libs/compute_api/src/spec.rs           | 172 ++++++++++++++++++++++-
 libs/utils/src/shard.rs                |   4 +
 pgxn/neon/libpagestore.c               |   8 +-
 pgxn/neon/pagestore_client.h           |   2 +-
 storage_controller/src/compute_hook.rs |  82 ++++++++---
 test_runner/regress/test_basebackup.py |   6 +-
 12 files changed, 628 insertions(+), 276 deletions(-)

diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs
index 27d33d8cd8..1033670e2b 100644
--- a/compute_tools/src/compute.rs
+++ b/compute_tools/src/compute.rs
@@ -7,7 +7,7 @@ use compute_api::responses::{
 };
 use compute_api::spec::{
     ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, GenericOption,
-    PageserverProtocol, PgIdent, Role,
+    PageserverConnectionInfo, PageserverProtocol, PgIdent, Role,
 };
 use futures::StreamExt;
 use futures::future::join_all;
@@ -38,7 +38,7 @@ use utils::id::{TenantId, TimelineId};
 use utils::lsn::Lsn;
 use utils::measured_stream::MeasuredReader;
 use utils::pid_file;
-use utils::shard::{ShardCount, ShardIndex, ShardNumber};
+use utils::shard::{ShardIndex, ShardNumber, ShardStripeSize};
 
 use crate::configurator::launch_configurator;
 use crate::disk_quota::set_disk_quota;
@@ -249,7 +249,7 @@ pub struct ParsedSpec {
     pub spec: ComputeSpec,
     pub tenant_id: TenantId,
     pub timeline_id: TimelineId,
-    pub pageserver_connstr: String,
+    pub pageserver_conninfo: PageserverConnectionInfo,
     pub safekeeper_connstrings: Vec<String>,
     pub storage_auth_token: Option<String>,
     /// k8s dns name and port
@@ -297,25 +297,47 @@ impl ParsedSpec {
 }
 
 impl TryFrom<ComputeSpec> for ParsedSpec {
-    type Error = String;
-    fn try_from(spec: ComputeSpec) -> Result<Self, Self::Error> {
+    type Error = anyhow::Error;
+    fn try_from(spec: ComputeSpec) -> Result<Self, Self::Error> {
         // Extract the options from the spec file that are needed to connect to
         // the storage system.
         //
-        // For backwards-compatibility, the top-level fields in the spec file
-        // may be empty. In that case, we need to dig them from the GUCs in the
-        // cluster.settings field.
-        let pageserver_connstr = spec
-            .pageserver_connstring
-            .clone()
-            .or_else(|| spec.cluster.settings.find("neon.pageserver_connstring"))
-            .ok_or("pageserver connstr should be provided")?;
+        // In compute specs generated by old control plane versions, the spec file might
+        // be missing the `pageserver_connection_info` field. In that case, we need to dig
+        // the pageserver connection info from the `pageserver_connstring` field instead, or
+        // if that's missing too, from the GUC in the cluster.settings field.
+        let mut pageserver_conninfo = spec.pageserver_connection_info.clone();
+        if pageserver_conninfo.is_none() {
+            if let Some(pageserver_connstr_field) = &spec.pageserver_connstring {
+                pageserver_conninfo = Some(PageserverConnectionInfo::from_connstr(
+                    pageserver_connstr_field,
+                    spec.shard_stripe_size,
+                )?);
+            }
+        }
+        if pageserver_conninfo.is_none() {
+            if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") {
+                let stripe_size = if let Some(guc) = spec.cluster.settings.find("neon.stripe_size")
+                {
+                    Some(ShardStripeSize(u32::from_str(&guc)?))
+                } else {
+                    None
+                };
+                pageserver_conninfo =
+                    Some(PageserverConnectionInfo::from_connstr(&guc, stripe_size)?);
+            }
+        }
+        let pageserver_conninfo = pageserver_conninfo.ok_or(anyhow::anyhow!(
+            "pageserver connection information should be provided"
+        ))?;
+
+        // Similarly for safekeeper connection strings
         let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() {
             if matches!(spec.mode, ComputeMode::Primary) {
                 spec.cluster
                     .settings
                     .find("neon.safekeepers")
-                    .ok_or("safekeeper connstrings should be provided")?
+                    .ok_or(anyhow::anyhow!("safekeeper connstrings should be provided"))?
                     .split(',')
                     .map(|str| str.to_string())
                     .collect()
@@ -330,22 +352,22 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
         let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id {
             tenant_id
         } else {
-            spec.cluster
+            let guc = spec
+                .cluster
                 .settings
                 .find("neon.tenant_id")
-                .ok_or("tenant id should be provided")
-                .map(|s| TenantId::from_str(&s))?
-                .or(Err("invalid tenant id"))?
+                .ok_or(anyhow::anyhow!("tenant id should be provided"))?;
+            TenantId::from_str(&guc).context("invalid tenant id")?
         };
         let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id {
             timeline_id
         } else {
-            spec.cluster
+            let guc = spec
+                .cluster
                 .settings
                 .find("neon.timeline_id")
-                .ok_or("timeline id should be provided")
-                .map(|s| TimelineId::from_str(&s))?
-                .or(Err("invalid timeline id"))?
+                .ok_or(anyhow::anyhow!("timeline id should be provided"))?;
+            TimelineId::from_str(&guc).context(anyhow::anyhow!("invalid timeline id"))?
         };
 
         let endpoint_storage_addr: Option<String> = spec
@@ -359,7 +381,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
         let res = ParsedSpec {
             spec,
-            pageserver_connstr,
+            pageserver_conninfo,
             safekeeper_connstrings,
             storage_auth_token,
             tenant_id,
@@ -369,7 +391,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
         };
 
         // Now check validity of the parsed specification
-        res.validate()?;
+        res.validate().map_err(anyhow::Error::msg)?;
         Ok(res)
     }
 }
@@ -1195,12 +1217,10 @@
     fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> {
         let spec = compute_state.pspec.as_ref().expect("spec must be set");
-        let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
         let started = Instant::now();
-
-        let (connected, size) = match PageserverProtocol::from_connstring(shard0_connstr)? {
-            PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?,
+        let (connected, size) = match spec.pageserver_conninfo.prefer_protocol {
             PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
+            PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?,
         };
 
         self.fix_zenith_signal_neon_signal()?;
@@ -1238,23 +1258,20 @@
     /// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when
     /// the connection was established, and the (compressed) size of the basebackup.
     fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
-        let shard0_connstr = spec
-            .pageserver_connstr
-            .split(',')
-            .next()
-            .unwrap()
-            .to_string();
-        let shard_index = match spec.pageserver_connstr.split(',').count() as u8 {
-            0 | 1 => ShardIndex::unsharded(),
-            count => ShardIndex::new(ShardNumber(0), ShardCount(count)),
+        let shard0_index = ShardIndex {
+            shard_number: ShardNumber(0),
+            shard_count: spec.pageserver_conninfo.shard_count,
         };
-
+        let shard0_url = spec
+            .pageserver_conninfo
+            .shard_url(ShardNumber(0), PageserverProtocol::Grpc)?
+            .to_owned();
         let (reader, connected) = tokio::runtime::Handle::current().block_on(async move {
             let mut client = page_api::Client::connect(
-                shard0_connstr,
+                shard0_url,
                 spec.tenant_id,
                 spec.timeline_id,
-                shard_index,
+                shard0_index,
                 spec.storage_auth_token.clone(),
                 None, // NB: base backups use payload compression
             )
             .await
@@ -1286,7 +1303,9 @@
     /// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp
     /// when the connection was established, and the (compressed) size of the basebackup.
     fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
-        let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
+        let shard0_connstr = spec
+            .pageserver_conninfo
+            .shard_url(ShardNumber(0), PageserverProtocol::Libpq)?;
         let mut config = postgres::Config::from_str(shard0_connstr)?;
 
         // Use the storage auth token from the config file, if given.
@@ -1373,10 +1392,7 @@
                     return result;
                 }
                 Err(ref e) if attempts < max_attempts => {
-                    warn!(
-                        "Failed to get basebackup: {} (attempt {}/{})",
-                        e, attempts, max_attempts
-                    );
+                    warn!("Failed to get basebackup: {e:?} (attempt {attempts}/{max_attempts})");
                     std::thread::sleep(std::time::Duration::from_millis(retry_period_ms as u64));
                     retry_period_ms *= 1.5;
                 }
@@ -1589,16 +1605,8 @@
             }
         };
 
-        info!(
-            "getting basebackup@{} from pageserver {}",
-            lsn, &pspec.pageserver_connstr
-        );
-        self.get_basebackup(compute_state, lsn).with_context(|| {
-            format!(
-                "failed to get basebackup@{} from pageserver {}",
-                lsn, &pspec.pageserver_connstr
-            )
-        })?;
+        self.get_basebackup(compute_state, lsn)
+            .with_context(|| format!("failed to get basebackup@{lsn}"))?;
 
         if let Some(settings) = databricks_settings {
             copy_tls_certificates(
@@ -2642,22 +2650,22 @@ LIMIT 100",
     /// The operation will time out after a specified duration.
     pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
         let state = self.state.lock().unwrap();
-        let old_pageserver_connstr = state
+        let old_pageserver_conninfo = state
             .pspec
             .as_ref()
             .expect("spec must be set")
-            .pageserver_connstr
+            .pageserver_conninfo
             .clone();
         let mut unchanged = true;
         let _ = self
             .state_changed
             .wait_timeout_while(state, duration, |s| {
-                let pageserver_connstr = &s
+                let pageserver_conninfo = &s
                     .pspec
                     .as_ref()
                     .expect("spec must be set")
-                    .pageserver_connstr;
-                unchanged = pageserver_connstr == &old_pageserver_connstr;
+                    .pageserver_conninfo;
+                unchanged = pageserver_conninfo == &old_pageserver_conninfo;
                 unchanged
             })
             .unwrap();
@@ -2915,7 +2923,10 @@ mod tests {
 
         match ParsedSpec::try_from(spec.clone()) {
             Ok(_p) => panic!("Failed to detect duplicate entry"),
-            Err(e) => assert!(e.starts_with("duplicate entry in safekeeper_connstrings:")),
+            Err(e) => assert!(
+                e.to_string()
+                    .starts_with("duplicate entry in safekeeper_connstrings:")
+            ),
         };
     }
 }
diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs
index 55a1eda0b7..e7dde5c5f5 100644
--- a/compute_tools/src/config.rs
+++ b/compute_tools/src/config.rs
@@ -18,6 +18,8 @@ use crate::pg_helpers::{
 };
 use crate::tls::{self, SERVER_CRT, SERVER_KEY};
 
+use utils::shard::{ShardIndex, ShardNumber};
+
 /// Check that `line` is inside a text file and put it there if it is not.
 /// Create file if it doesn't exist.
 pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
@@ -69,9 +71,75 @@ pub fn write_postgres_conf(
     }
     // Add options for connecting to storage
     writeln!(file, "# Neon storage settings")?;
-    if let Some(s) = &spec.pageserver_connstring {
-        writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
+    writeln!(file)?;
+    if let Some(conninfo) = &spec.pageserver_connection_info {
+        let mut libpq_urls: Option<Vec<String>> = Some(Vec::new());
+        let num_shards = if conninfo.shard_count.0 == 0 {
+            1 // unsharded, treat it as a single shard
+        } else {
+            conninfo.shard_count.0
+        };
+
+        for shard_number in 0..num_shards {
+            let shard_index = ShardIndex {
+                shard_number: ShardNumber(shard_number),
+                shard_count: conninfo.shard_count,
+            };
+            let info = conninfo.shards.get(&shard_index).ok_or_else(|| {
+                anyhow::anyhow!(
+                    "shard {shard_index} missing from pageserver_connection_info shard map"
+                )
+            })?;
+
+            let first_pageserver = info
+                .pageservers
+                .first()
+                .expect("must have at least one pageserver");
+
+            // Add the libpq URL to the array, or if the URL is missing, reset the array
+            // forgetting any previous entries. All servers must have a libpq URL, or none
+            // at all.
+            if let Some(url) = &first_pageserver.libpq_url {
+                if let Some(ref mut urls) = libpq_urls {
+                    urls.push(url.clone());
+                }
+            } else {
+                libpq_urls = None
+            }
+        }
+        if let Some(libpq_urls) = libpq_urls {
+            writeln!(
+                file,
+                "# derived from compute spec's pageserver_conninfo field"
+            )?;
+            writeln!(
+                file,
+                "neon.pageserver_connstring={}",
+                escape_conf_value(&libpq_urls.join(","))
+            )?;
+        } else {
+            writeln!(file, "# no neon.pageserver_connstring")?;
+        }
+
+        if let Some(stripe_size) = conninfo.stripe_size {
+            writeln!(
+                file,
+                "# from compute spec's pageserver_conninfo.stripe_size field"
+            )?;
+            writeln!(file, "neon.stripe_size={stripe_size}")?;
+        }
+    } else {
+        if let Some(s) = &spec.pageserver_connstring {
+            writeln!(file, "# from compute spec's pageserver_connstring field")?;
+            writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
+        }
+
+        if let Some(stripe_size) = spec.shard_stripe_size {
+            writeln!(file, "# from compute spec's shard_stripe_size field")?;
+            writeln!(file, "neon.stripe_size={stripe_size}")?;
+        }
     }
+
     if !spec.safekeeper_connstrings.is_empty() {
         let mut neon_safekeepers_value = String::new();
         tracing::info!(
diff --git a/compute_tools/src/configurator.rs b/compute_tools/src/configurator.rs
index feca8337b2..79eb80c4a0 100644
--- a/compute_tools/src/configurator.rs
+++ b/compute_tools/src/configurator.rs
@@ -122,8 +122,11 @@ fn configurator_main_loop(compute: &Arc<ComputeNode>) {
             // into the type system.
             assert_eq!(state.status, ComputeStatus::RefreshConfiguration);
 
-            if state.pspec.as_ref().map(|ps| ps.pageserver_connstr.clone())
-                == Some(pspec.pageserver_connstr.clone())
+            if state
+                .pspec
+                .as_ref()
+                .map(|ps| ps.pageserver_conninfo.clone())
+                == Some(pspec.pageserver_conninfo.clone())
             {
                 info!(
                     "Refresh configuration: Retrieved spec is the same as the current spec. Waiting for control plane to update the spec before attempting reconfiguration."
diff --git a/compute_tools/src/lsn_lease.rs b/compute_tools/src/lsn_lease.rs
index bb0828429d..6abfea82e0 100644
--- a/compute_tools/src/lsn_lease.rs
+++ b/compute_tools/src/lsn_lease.rs
@@ -4,14 +4,13 @@ use std::thread;
 use std::time::{Duration, SystemTime};
 
 use anyhow::{Result, bail};
-use compute_api::spec::{ComputeMode, PageserverProtocol};
-use itertools::Itertools as _;
+use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverProtocol};
 use pageserver_page_api as page_api;
 use postgres::{NoTls, SimpleQueryMessage};
 use tracing::{info, warn};
 use utils::id::{TenantId, TimelineId};
 use utils::lsn::Lsn;
-use utils::shard::{ShardCount, ShardNumber, TenantShardId};
+use utils::shard::TenantShardId;
 
 use crate::compute::ComputeNode;
 
@@ -78,17 +77,16 @@ fn acquire_lsn_lease_with_retry(
     loop {
         // Note: List of pageservers is dynamic, need to re-read configs before each attempt.
-        let (connstrings, auth) = {
+        let (conninfo, auth) = {
             let state = compute.state.lock().unwrap();
             let spec = state.pspec.as_ref().expect("spec must be set");
             (
-                spec.pageserver_connstr.clone(),
+                spec.pageserver_conninfo.clone(),
                 spec.storage_auth_token.clone(),
             )
         };
 
-        let result =
-            try_acquire_lsn_lease(&connstrings, auth.as_deref(), tenant_id, timeline_id, lsn);
+        let result = try_acquire_lsn_lease(conninfo, auth.as_deref(), tenant_id, timeline_id, lsn);
         match result {
             Ok(Some(res)) => {
                 return Ok(res);
@@ -112,35 +110,44 @@
 
 /// Tries to acquire LSN leases on all Pageserver shards.
 fn try_acquire_lsn_lease(
-    connstrings: &str,
+    conninfo: PageserverConnectionInfo,
     auth: Option<&str>,
     tenant_id: TenantId,
     timeline_id: TimelineId,
     lsn: Lsn,
 ) -> Result<Option<SystemTime>> {
-    let connstrings = connstrings.split(',').collect_vec();
-    let shard_count = connstrings.len();
     let mut leases = Vec::new();
-    for (shard_number, &connstring) in connstrings.iter().enumerate() {
-        let tenant_shard_id = match shard_count {
-            0 | 1 => TenantShardId::unsharded(tenant_id),
-            shard_count => TenantShardId {
-                tenant_id,
-                shard_number: ShardNumber(shard_number as u8),
-                shard_count: ShardCount::new(shard_count as u8),
-            },
+    for (shard_index, shard) in conninfo.shards.into_iter() {
+        let tenant_shard_id = TenantShardId {
+            tenant_id,
+            shard_number: shard_index.shard_number,
+            shard_count: shard_index.shard_count,
         };
-        let lease = match PageserverProtocol::from_connstring(connstring)? {
-            PageserverProtocol::Libpq => {
-                acquire_lsn_lease_libpq(connstring, auth, tenant_shard_id, timeline_id, lsn)?
-            }
-            PageserverProtocol::Grpc => {
-                acquire_lsn_lease_grpc(connstring, auth, tenant_shard_id, timeline_id, lsn)?
-            }
-        };
-        leases.push(lease);
+        // XXX: If there is more than one pageserver for a shard, do we need to get a
+        // lease on all of them? Currently, that's what we assume, but this is hypothetical
+        // as of this writing, as we never pass the info for more than one pageserver per
+        // shard.
+        for pageserver in shard.pageservers {
+            let lease = match conninfo.prefer_protocol {
+                PageserverProtocol::Grpc => acquire_lsn_lease_grpc(
+                    &pageserver.grpc_url.unwrap(),
+                    auth,
+                    tenant_shard_id,
+                    timeline_id,
+                    lsn,
+                )?,
+                PageserverProtocol::Libpq => acquire_lsn_lease_libpq(
+                    &pageserver.libpq_url.unwrap(),
+                    auth,
+                    tenant_shard_id,
+                    timeline_id,
+                    lsn,
+                )?,
+            };
+            leases.push(lease);
+        }
     }
 
     Ok(leases.into_iter().min().flatten())
diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs
index c126835066..95500b0b18 100644
--- a/control_plane/src/bin/neon_local.rs
+++ b/control_plane/src/bin/neon_local.rs
@@ -19,6 +19,9 @@ use compute_api::requests::ComputeClaimsScope;
 use compute_api::spec::{ComputeMode, PageserverProtocol};
 use control_plane::broker::StorageBroker;
 use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode};
+use control_plane::endpoint::{
+    local_pageserver_conf_to_conn_info, tenant_locate_response_to_conn_info,
+};
 use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
 use control_plane::local_env;
 use control_plane::local_env::{
@@ -44,7 +47,6 @@ use pageserver_api::models::{
 };
 use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
 use postgres_backend::AuthType;
-use postgres_connection::parse_host_port;
 use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
 use safekeeper_api::{
     DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
 };
 use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR;
 use tokio::task::JoinSet;
-use url::Host;
 use utils::auth::{Claims, Scope};
 use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
 use utils::lsn::Lsn;
@@ -1546,62 +1547,41 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
                 )?;
             }
 
-            let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id {
-                let conf = env.get_pageserver_conf(pageserver_id).unwrap();
-                // Use gRPC if requested.
-                let pageserver = if endpoint.grpc {
-                    let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
-                    let (host, port) = parse_host_port(grpc_addr)?;
-                    let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
-                    (PageserverProtocol::Grpc, host, port)
-                } else {
-                    let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
-                    let port = port.unwrap_or(5432);
-                    (PageserverProtocol::Libpq, host, port)
-                };
-                // If caller is telling us what pageserver to use, this is not a tenant which is
-                // fully managed by storage controller, therefore not sharded.
-                (vec![pageserver], DEFAULT_STRIPE_SIZE)
+            let prefer_protocol = if endpoint.grpc {
+                PageserverProtocol::Grpc
+            } else {
+                PageserverProtocol::Libpq
+            };
+
+            let mut pageserver_conninfo = if let Some(ps_id) = pageserver_id {
+                let conf = env.get_pageserver_conf(ps_id).unwrap();
+                local_pageserver_conf_to_conn_info(conf)?
             } else {
                 // Look up the currently attached location of the tenant, and its striping metadata,
                 // to pass these on to postgres.
                 let storage_controller = StorageController::from_env(env);
                 let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
-                let pageservers = futures::future::try_join_all(
-                    locate_result.shards.into_iter().map(|shard| async move {
-                        if let ComputeMode::Static(lsn) = endpoint.mode {
-                            // Initialize LSN leases for static computes.
+                assert!(!locate_result.shards.is_empty());
+
+                // Initialize LSN leases for static computes.
+                if let ComputeMode::Static(lsn) = endpoint.mode {
+                    futures::future::try_join_all(locate_result.shards.iter().map(
+                        |shard| async move {
                             let conf = env.get_pageserver_conf(shard.node_id).unwrap();
                             let pageserver = PageServerNode::from_env(env, conf);
 
                             pageserver
                                 .http_client
                                 .timeline_init_lsn_lease(shard.shard_id, endpoint.timeline_id, lsn)
-                                .await?;
-                        }
+                                .await
+                        },
+                    ))
+                    .await?;
+                }
 
-                        let pageserver = if endpoint.grpc {
-                            (
-                                PageserverProtocol::Grpc,
-                                Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))?,
-                                shard.listen_grpc_port.expect("no gRPC port"),
-                            )
-                        } else {
-                            (
-                                PageserverProtocol::Libpq,
-                                Host::parse(&shard.listen_pg_addr)?,
-                                shard.listen_pg_port,
-                            )
-                        };
-                        anyhow::Ok(pageserver)
-                    }),
-                )
-                .await?;
-                let stripe_size = locate_result.shard_params.stripe_size;
-
-                (pageservers, stripe_size)
+                tenant_locate_response_to_conn_info(&locate_result)?
             };
-            assert!(!pageservers.is_empty());
+            pageserver_conninfo.prefer_protocol = prefer_protocol;
 
             let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
             let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
@@ -1631,9 +1611,8 @@
                 endpoint_storage_addr,
                 safekeepers_generation,
                 safekeepers,
-                pageservers,
+                pageserver_conninfo,
                 remote_ext_base_url: remote_ext_base_url.clone(),
-                shard_stripe_size: stripe_size.0 as usize,
                 create_test_user: args.create_test_user,
                 start_timeout: args.start_timeout,
                 autoprewarm: args.autoprewarm,
@@ -1650,37 +1629,29 @@
                 .endpoints
                 .get(endpoint_id.as_str())
                 .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
-            let pageservers = match args.pageserver_id {
+            let prefer_protocol = if endpoint.grpc {
+                PageserverProtocol::Grpc
+            } else {
+                PageserverProtocol::Libpq
+            };
+            let mut pageserver_conninfo = match args.pageserver_id {
                 Some(pageserver_id) => {
-                    let pageserver =
-                        PageServerNode::from_env(env, env.get_pageserver_conf(pageserver_id)?);
-
-                    vec![(
-                        PageserverProtocol::Libpq,
-                        pageserver.pg_connection_config.host().clone(),
-                        pageserver.pg_connection_config.port(),
-                    )]
+                    let conf = env.get_pageserver_conf(pageserver_id)?;
+                    local_pageserver_conf_to_conn_info(conf)?
                 }
                 None => {
                     let storage_controller = StorageController::from_env(env);
-                    storage_controller
-                        .tenant_locate(endpoint.tenant_id)
-                        .await?
-                        .shards
-                        .into_iter()
-                        .map(|shard| {
-                            (
-                                PageserverProtocol::Libpq,
-                                Host::parse(&shard.listen_pg_addr)
-                                    .expect("Storage controller reported malformed host"),
-                                shard.listen_pg_port,
-                            )
-                        })
-                        .collect::<Vec<_>>()
+                    let locate_result =
+                        storage_controller.tenant_locate(endpoint.tenant_id).await?;
+
+                    tenant_locate_response_to_conn_info(&locate_result)?
                 }
             };
+            pageserver_conninfo.prefer_protocol = prefer_protocol;
 
-            endpoint.update_pageservers_in_config(pageservers).await?;
+            endpoint
+                .update_pageservers_in_config(&pageserver_conninfo)
+                .await?;
         }
         EndpointCmd::Reconfigure(args) => {
             let endpoint_id = &args.endpoint_id;
@@ -1688,51 +1659,30 @@
                 .endpoints
                 .get(endpoint_id.as_str())
                 .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
-            let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id {
-                let conf = env.get_pageserver_conf(ps_id)?;
-                // Use gRPC if requested.
-                let pageserver = if endpoint.grpc {
-                    let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
-                    let (host, port) = parse_host_port(grpc_addr)?;
-                    let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
-                    (PageserverProtocol::Grpc, host, port)
-                } else {
-                    let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
-                    let port = port.unwrap_or(5432);
-                    (PageserverProtocol::Libpq, host, port)
-                };
-                vec![pageserver]
+
+            let prefer_protocol = if endpoint.grpc {
+                PageserverProtocol::Grpc
             } else {
-                let storage_controller = StorageController::from_env(env);
-                storage_controller
-                    .tenant_locate(endpoint.tenant_id)
-                    .await?
-                    .shards
-                    .into_iter()
-                    .map(|shard| {
-                        // Use gRPC if requested.
-                        if endpoint.grpc {
-                            (
-                                PageserverProtocol::Grpc,
-                                Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))
-                                    .expect("bad hostname"),
-                                shard.listen_grpc_port.expect("no gRPC port"),
-                            )
-                        } else {
-                            (
-                                PageserverProtocol::Libpq,
-                                Host::parse(&shard.listen_pg_addr).expect("bad hostname"),
-                                shard.listen_pg_port,
-                            )
-                        }
-                    })
-                    .collect::<Vec<_>>()
+                PageserverProtocol::Libpq
             };
+            let mut pageserver_conninfo = if let Some(ps_id) = args.endpoint_pageserver_id {
+                let conf = env.get_pageserver_conf(ps_id)?;
+                local_pageserver_conf_to_conn_info(conf)?
+            } else {
+                // Look up the currently attached location of the tenant, and its striping metadata,
+                // to pass these on to postgres.
+                let storage_controller = StorageController::from_env(env);
+                let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
+
+                tenant_locate_response_to_conn_info(&locate_result)?
+            };
+            pageserver_conninfo.prefer_protocol = prefer_protocol;
+
             // If --safekeepers argument is given, use only the listed
             // safekeeper nodes; otherwise all from the env.
             let safekeepers = parse_safekeepers(&args.safekeepers)?;
             endpoint
-                .reconfigure(Some(pageservers), None, safekeepers, None)
+                .reconfigure(Some(&pageserver_conninfo), safekeepers, None)
                 .await?;
         }
         EndpointCmd::RefreshConfiguration(args) => {
diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs
index 1c7f489d68..814ee2a52f 100644
--- a/control_plane/src/endpoint.rs
+++ b/control_plane/src/endpoint.rs
@@ -37,7 +37,7 @@
 //!
 //! ```
 //!
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 use std::fmt::Display;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
 use std::path::PathBuf;
@@ -58,8 +58,12 @@ use compute_api::responses::{
 };
 use compute_api::spec::{
     Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol,
-    PgIdent, RemoteExtSpec, Role,
+    PageserverShardInfo, PgIdent, RemoteExtSpec, Role,
 };
+
+// re-export these, because they're used in the reconfigure() function
+pub use compute_api::spec::{PageserverConnectionInfo, PageserverShardConnectionInfo};
+
 use jsonwebtoken::jwk::{
     AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
     OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
@@ -74,9 +78,11 @@ use sha2::{Digest, Sha256};
 use spki::der::Decode;
 use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
 use tracing::debug;
-use url::Host;
 use utils::id::{NodeId, TenantId, TimelineId};
-use utils::shard::ShardStripeSize;
+use utils::shard::{ShardCount, ShardIndex, ShardNumber};
+
+use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT;
+use postgres_connection::parse_host_port;
 
 use crate::local_env::LocalEnv;
 use crate::postgresql_conf::PostgresConf;
@@ -387,9 +393,8 @@ pub struct EndpointStartArgs {
     pub endpoint_storage_addr: String,
     pub safekeepers_generation: Option<SafekeeperGeneration>,
     pub safekeepers: Vec<NodeId>,
-    pub pageservers: Vec<(PageserverProtocol, Host, u16)>,
+    pub pageserver_conninfo: PageserverConnectionInfo,
     pub remote_ext_base_url: Option<String>,
-    pub shard_stripe_size: usize,
     pub create_test_user: bool,
     pub start_timeout: Duration,
     pub autoprewarm: bool,
@@ -662,14 +667,6 @@ impl Endpoint {
         }
     }
 
-    fn build_pageserver_connstr(pageservers: &[(PageserverProtocol, Host, u16)]) -> String {
-        pageservers
-            .iter()
-            .map(|(scheme, host, port)| format!("{scheme}://no_user@{host}:{port}"))
-            .collect::<Vec<_>>()
-            .join(",")
-    }
-
     /// Map safekeepers ids to the actual connection strings.
     fn build_safekeepers_connstrs(&self, sk_ids: Vec<NodeId>) -> Result<Vec<String>> {
         let mut safekeeper_connstrings = Vec::new();
@@ -715,9 +712,6 @@ impl Endpoint {
             std::fs::remove_dir_all(self.pgdata())?;
         }
 
-        let pageserver_connstring = Self::build_pageserver_connstr(&args.pageservers);
-        assert!(!pageserver_connstring.is_empty());
-
         let safekeeper_connstrings = self.build_safekeepers_connstrs(args.safekeepers)?;
 
         // check for file remote_extensions_spec.json
@@ -732,6 +726,44 @@
             remote_extensions = None;
         };
 
+        // For the sake of backwards-compatibility, also fill in 'pageserver_connstring'
+        //
+        // XXX: I believe this is not really needed, except to make
+        // test_forward_compatibility happy.
+        //
+        // Use a closure so that we can conveniently return None in the middle of the
+        // loop.
+        let pageserver_connstring: Option<String> = (|| {
+            let num_shards = args.pageserver_conninfo.shard_count.count();
+            let mut connstrings = Vec::new();
+            for shard_no in 0..num_shards {
+                let shard_index = ShardIndex {
+                    shard_count: args.pageserver_conninfo.shard_count,
+                    shard_number: ShardNumber(shard_no),
+                };
+                let shard = args
+                    .pageserver_conninfo
+                    .shards
+                    .get(&shard_index)
+                    .ok_or_else(|| {
+                        anyhow!(
+                            "shard {} not found in pageserver_connection_info",
+                            shard_index
+                        )
+                    })?;
+                let pageserver = shard
+                    .pageservers
+                    .first()
+                    .ok_or(anyhow!("must have at least one pageserver"))?;
+                if let Some(libpq_url) = &pageserver.libpq_url {
+                    connstrings.push(libpq_url.clone());
+                } else {
+                    return Ok::<_, anyhow::Error>(None);
+                }
+            }
+            Ok(Some(connstrings.join(",")))
+        })()?;
+
         // Create config file
         let config = {
             let mut spec = ComputeSpec {
@@ -776,13 +808,14 @@
                 branch_id: None,
                 endpoint_id: Some(self.endpoint_id.clone()),
                 mode: self.mode,
-                pageserver_connstring: Some(pageserver_connstring),
+                pageserver_connection_info: Some(args.pageserver_conninfo.clone()),
+                pageserver_connstring,
                 safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()),
                 safekeeper_connstrings,
                 storage_auth_token: args.auth_token.clone(),
                 remote_extensions,
                 pgbouncer_settings: None,
-                shard_stripe_size: Some(args.shard_stripe_size),
+                shard_stripe_size: args.pageserver_conninfo.stripe_size, // redundant with pageserver_connection_info.stripe_size
                 local_proxy_config: None,
                 reconfigure_concurrency: self.reconfigure_concurrency,
                 drop_subscriptions_before_start: self.drop_subscriptions_before_start,
@@ -966,7 +999,7 @@
     // Update the pageservers in the spec file of the endpoint. This is useful to test the spec refresh scenario.
     pub async fn update_pageservers_in_config(
         &self,
-        pageservers: Vec<(PageserverProtocol, Host, u16)>,
+        pageserver_conninfo: &PageserverConnectionInfo,
     ) -> Result<()> {
         let config_path = self.endpoint_path().join("config.json");
         let mut config: ComputeConfig = {
             let file = std::fs::File::open(&config_path)?;
             serde_json::from_reader(file)?
         };
 
-        let pageserver_connstring = Self::build_pageserver_connstr(&pageservers);
-        assert!(!pageserver_connstring.is_empty());
         let mut spec = config.spec.unwrap();
-        spec.pageserver_connstring = Some(pageserver_connstring);
+        spec.pageserver_connection_info = Some(pageserver_conninfo.clone());
         config.spec = Some(spec);
 
         let file = std::fs::File::create(&config_path)?;
@@ -1020,8 +1051,7 @@
 
     pub async fn reconfigure(
         &self,
-        pageservers: Option<Vec<(PageserverProtocol, Host, u16)>>,
-        stripe_size: Option<ShardStripeSize>,
+        pageserver_conninfo: Option<&PageserverConnectionInfo>,
         safekeepers: Option<Vec<NodeId>>,
         safekeeper_generation: Option<SafekeeperGeneration>,
     ) -> Result<()> {
@@ -1036,15 +1066,15 @@
         let postgresql_conf = self.read_postgresql_conf()?;
         spec.cluster.postgresql_conf = Some(postgresql_conf);
 
-        // If pageservers are not specified, don't change them.
-        if let Some(pageservers) = pageservers {
-            anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided");
-
-            let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
-            spec.pageserver_connstring = Some(pageserver_connstr);
-            if stripe_size.is_some() {
-                spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
-            }
+        if let Some(pageserver_conninfo) = pageserver_conninfo {
+            // If pageservers are provided, we need to ensure that they are not empty.
+            // This is a requirement for the compute_ctl configuration.
+            anyhow::ensure!(
+                !pageserver_conninfo.shards.is_empty(),
+                "no pageservers provided"
+            );
+            spec.pageserver_connection_info = Some(pageserver_conninfo.clone());
+            spec.shard_stripe_size = pageserver_conninfo.stripe_size;
         }
 
         // If safekeepers are not specified, don't change them.
@@ -1093,11 +1123,9 @@
 
     pub async fn reconfigure_pageservers(
         &self,
-        pageservers: Vec<(PageserverProtocol, Host, u16)>,
-        stripe_size: Option<ShardStripeSize>,
+        pageservers: &PageserverConnectionInfo,
     ) -> Result<()> {
-        self.reconfigure(Some(pageservers), stripe_size, None, None)
-            .await
+        self.reconfigure(Some(pageservers), None, None).await
     }
 
     pub async fn reconfigure_safekeepers(
         &self,
         safekeepers: Vec<NodeId>,
         generation: SafekeeperGeneration,
     ) -> Result<()> {
-        self.reconfigure(None, None, Some(safekeepers), Some(generation))
+        self.reconfigure(None, Some(safekeepers), Some(generation))
             .await
     }
 
@@ -1188,3 +1216,84 @@ impl Endpoint {
         )
     }
 }
+
+/// If caller is telling us what pageserver to use, this is not a tenant which is
+/// fully managed by storage controller, therefore not sharded.
+pub fn local_pageserver_conf_to_conn_info(
+    conf: &crate::local_env::PageServerConf,
+) -> Result<PageserverConnectionInfo> {
+    let libpq_url = {
+        let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
+        let port = port.unwrap_or(5432);
+        Some(format!("postgres://no_user@{host}:{port}"))
+    };
+    let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr {
+        let (host, port) = parse_host_port(grpc_addr)?;
+        let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
+        Some(format!("grpc://no_user@{host}:{port}"))
+    } else {
+        None
+    };
+    let ps_conninfo = PageserverShardConnectionInfo {
+        id: Some(conf.id),
+        libpq_url,
+        grpc_url,
+    };
+
+    let shard_info = PageserverShardInfo {
+        pageservers: vec![ps_conninfo],
+    };
+
+    let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)]
+        .into_iter()
+        .collect();
+    Ok(PageserverConnectionInfo {
+        shard_count: ShardCount::unsharded(),
+        stripe_size: None,
+        shards,
+        prefer_protocol: PageserverProtocol::default(),
+    })
+}
+
+pub fn tenant_locate_response_to_conn_info(
+    response: &pageserver_api::controller_api::TenantLocateResponse,
+) -> Result<PageserverConnectionInfo> {
+    let mut shards = HashMap::new();
+    for shard in response.shards.iter() {
+        tracing::info!("parsing {}", shard.listen_pg_addr);
+        let libpq_url = {
+            let host = &shard.listen_pg_addr;
+            let port = shard.listen_pg_port;
+            Some(format!("postgres://no_user@{host}:{port}"))
+        };
+        let grpc_url = if let Some(grpc_addr) = &shard.listen_grpc_addr {
+            let host = grpc_addr;
+            let port = shard.listen_grpc_port.expect("no gRPC port");
+            Some(format!("grpc://no_user@{host}:{port}"))
+        } else {
+            None
+        };
+
+        let shard_info = PageserverShardInfo {
+            pageservers: vec![PageserverShardConnectionInfo {
+                id: Some(shard.node_id),
+                libpq_url,
+                grpc_url,
+            }],
+        };
+
+        shards.insert(shard.shard_id.to_index(), shard_info);
+    }
+
+    let stripe_size = if response.shard_params.count.is_unsharded() {
+        None
+    } else {
+        Some(response.shard_params.stripe_size)
+    };
+    Ok(PageserverConnectionInfo {
+        shard_count: response.shard_params.count,
+        stripe_size,
+        shards,
+        prefer_protocol: PageserverProtocol::default(),
+    })
+}
diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs
index 6709c06fc6..12d825e1bf 100644
--- a/libs/compute_api/src/spec.rs
+++ b/libs/compute_api/src/spec.rs
@@ -12,8 +12,9 @@ use regex::Regex;
 use remote_storage::RemotePath;
 use serde::{Deserialize, Serialize};
 use url::Url;
-use utils::id::{TenantId, TimelineId};
+use utils::id::{NodeId, TenantId, TimelineId};
 use utils::lsn::Lsn;
+use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};
 
 use crate::responses::TlsConfig;
 
@@ -105,8 +106,27 @@ pub struct ComputeSpec {
     // updated to fill these fields, we can make these non optional.
     pub tenant_id: Option<TenantId>,
     pub timeline_id: Option<TimelineId>,
+
+    /// Pageserver information can be passed in three different ways:
+    /// 1. Here in `pageserver_connection_info`.
+    /// 2. In the `pageserver_connstring` field.
+    /// 3. In `cluster.settings`.
+    ///
+    /// The goal is to use method 1. everywhere. But for backwards-compatibility with old
+    /// versions of the control plane, `compute_ctl` will check 2. and 3. if the
+    /// `pageserver_connection_info` field is missing.
+    ///
+    /// If both `pageserver_connection_info` and `pageserver_connstring`+`shard_stripe_size` are
+    /// given, they must contain the same information.
+    pub pageserver_connection_info: Option<PageserverConnectionInfo>,
+
+    pub pageserver_connstring: Option<String>,
+    /// Stripe size for pageserver sharding, in pages. This is set together with the legacy
+    /// `pageserver_connstring` field. When the modern `pageserver_connection_info` field is used,
+    /// the stripe size is stored in `pageserver_connection_info.stripe_size` instead.
+    pub shard_stripe_size: Option<ShardStripeSize>,
+
     // More neon ids that we expose to the compute_ctl
     // and to postgres as neon extension GUCs.
     pub project_id: Option<String>,
@@ -139,10 +159,6 @@ pub struct ComputeSpec {
 
     pub pgbouncer_settings: Option<HashMap<String, String>>,
 
-    // Stripe size for pageserver sharding, in pages
-    #[serde(default)]
-    pub shard_stripe_size: Option<usize>,
-
     /// Local Proxy configuration used for JWT authentication
     #[serde(default)]
     pub local_proxy_config: Option<LocalProxySpec>,
@@ -217,6 +233,140 @@ pub enum ComputeFeature {
     UnknownFeature,
 }
 
+#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
+pub struct PageserverConnectionInfo {
+    /// NB: 0 for unsharded tenants, 1 for sharded tenants with 1 shard, following storage
+    /// conventions.
+    pub shard_count: ShardCount,
+
+    /// INVARIANT: null if shard_count is 0, otherwise non-null and immutable
+    pub stripe_size: Option<ShardStripeSize>,
+
+    pub shards: HashMap<ShardIndex, PageserverShardInfo>,
+
+    /// If the compute supports both protocols, this indicates which one it should use. The compute
+    /// may use other available protocols too, if it doesn't support the preferred one. The URLs
+    /// for the protocol specified here must be present for all shards, i.e. do not mark a protocol
+    /// as preferred if it cannot actually be used with all the pageservers.
+    #[serde(default)]
+    pub prefer_protocol: PageserverProtocol,
+}
+
+/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings.
+///
+/// This is used for backwards-compatibility, to parse the legacy
+/// [ComputeSpec::pageserver_connstring] field, or the 'neon.pageserver_connstring' GUC. Nowadays,
+/// the 'pageserver_connection_info' field should be used instead.
+impl PageserverConnectionInfo {
+    pub fn from_connstr(
+        connstr: &str,
+        stripe_size: Option<ShardStripeSize>,
+    ) -> anyhow::Result<Self> {
+        let shard_infos: Vec<_> = connstr
+            .split(',')
+            .map(|connstr| PageserverShardInfo {
+                pageservers: vec![PageserverShardConnectionInfo {
+                    id: None,
+                    libpq_url: Some(connstr.to_string()),
+                    grpc_url: None,
+                }],
+            })
+            .collect();
+
+        match shard_infos.len() {
+            0 => anyhow::bail!("empty connection string"),
+            1 => {
+                // We assume that if there's only one connection string, it means "unsharded",
+                // rather than a sharded system with just a single shard. The latter is
+                // possible in principle, but we never do it.
+                let shard_count = ShardCount::unsharded();
+                let only_shard = shard_infos.first().unwrap().clone();
+                let shards = vec![(ShardIndex::unsharded(), only_shard)];
+                Ok(PageserverConnectionInfo {
+                    shard_count,
+                    stripe_size: None,
+                    shards: shards.into_iter().collect(),
+                    prefer_protocol: PageserverProtocol::Libpq,
+                })
+            }
+            n => {
+                if stripe_size.is_none() {
+                    anyhow::bail!("{n} shards but no stripe_size");
+                }
+                let shard_count = ShardCount(n.try_into()?);
+                let shards = shard_infos
+                    .into_iter()
+                    .enumerate()
+                    .map(|(idx, shard_info)| {
+                        (
+                            ShardIndex {
+                                shard_count,
+                                shard_number: ShardNumber(
                                    idx.try_into().expect("shard number fits in u8"),
                                ),
+                            },
+                            shard_info,
+                        )
+                    })
+                    .collect();
+                Ok(PageserverConnectionInfo {
+                    shard_count,
+                    stripe_size,
+                    shards,
+                    prefer_protocol: PageserverProtocol::Libpq,
+                })
+            }
+        }
+    }
+
+    /// Convenience routine to get the connection string for a shard.
+    pub fn shard_url(
+        &self,
+        shard_number: ShardNumber,
+        protocol: PageserverProtocol,
+    ) -> anyhow::Result<&str> {
+        let shard_index = ShardIndex {
+            shard_number,
+            shard_count: self.shard_count,
+        };
+        let shard = self.shards.get(&shard_index).ok_or_else(|| {
+            anyhow::anyhow!("shard connection info missing for shard {}", shard_index)
+        })?;
+
+        // Just use the first pageserver in the list. That's good enough for this
+        // convenience routine; if you need more control, like round robin policy or
+        // failover support, roll your own. (As of this writing, we never have more than
+        // one pageserver per shard anyway, but that will change in the future.)
+        let pageserver = shard
+            .pageservers
+            .first()
+            .ok_or(anyhow::anyhow!("must have at least one pageserver"))?;
+
+        let result = match protocol {
+            PageserverProtocol::Grpc => pageserver
+                .grpc_url
+                .as_ref()
+                .ok_or(anyhow::anyhow!("no grpc_url for shard {shard_index}"))?,
+            PageserverProtocol::Libpq => pageserver
+                .libpq_url
+                .as_ref()
+                .ok_or(anyhow::anyhow!("no libpq_url for shard {shard_index}"))?,
+        };
+        Ok(result)
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
+pub struct PageserverShardInfo {
+    pub pageservers: Vec<PageserverShardConnectionInfo>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
+pub struct PageserverShardConnectionInfo {
+    pub id: Option<NodeId>,
+    pub libpq_url: Option<String>,
+    pub grpc_url: Option<String>,
+}
+
 #[derive(Clone, Debug, Default, Deserialize, Serialize)]
 pub struct RemoteExtSpec {
     pub public_extensions: Option<Vec<String>>,
@@ -334,6 +484,12 @@ impl ComputeMode {
     }
 }
 
+impl Display for ComputeMode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(self.to_type_str())
+    }
+}
+
 /// Log level for audit logging
 #[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
 pub enum ComputeAudit {
@@ -470,13 +626,15 @@ pub struct JwksSettings {
     pub jwt_audience: Option<String>,
 }
 
-/// Protocol used to connect to a Pageserver. Parsed from the connstring scheme.
-#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
+/// Protocol used to connect to a Pageserver.
+#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
 pub enum PageserverProtocol {
     /// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme.
     #[default]
+    #[serde(rename = "libpq")]
     Libpq,
     /// A newer, gRPC-based protocol. Uses grpc:// scheme.
+    #[serde(rename = "grpc")]
     Grpc,
 }
diff --git a/libs/utils/src/shard.rs b/libs/utils/src/shard.rs
index 6ad6cab3a8..90323f7762 100644
--- a/libs/utils/src/shard.rs
+++ b/libs/utils/src/shard.rs
@@ -59,6 +59,10 @@ impl ShardCount {
     pub const MAX: Self = Self(u8::MAX);
     pub const MIN: Self = Self(0);
 
+    pub fn unsharded() -> Self {
+        ShardCount(0)
+    }
+
     /// The internal value of a ShardCount may be zero, which means "1 shard, but use
     /// legacy format for TenantShardId that excludes the shard suffix", also known
     /// as [`TenantShardId::unsharded`].
diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c
index 158118f860..87bdbf376c 100644
--- a/pgxn/neon/libpagestore.c
+++ b/pgxn/neon/libpagestore.c
@@ -71,7 +71,7 @@ char *neon_project_id;
 char       *neon_branch_id;
 char       *neon_endpoint_id;
 int32       max_cluster_size;
-char       *page_server_connstring;
+char       *pageserver_connstring;
 char       *neon_auth_token;
 
 int         readahead_buffer_size = 128;
@@ -1453,7 +1453,7 @@ PagestoreShmemInit(void)
         pg_atomic_init_u64(&pagestore_shared->begin_update_counter, 0);
         pg_atomic_init_u64(&pagestore_shared->end_update_counter, 0);
         memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap));
-        AssignPageserverConnstring(page_server_connstring, NULL);
+        AssignPageserverConnstring(pageserver_connstring, NULL);
     }
 }
 
@@ -1472,7 +1472,7 @@ pg_init_libpagestore(void)
     DefineCustomStringVariable("neon.pageserver_connstring",
                                "connection string to the page server",
                                NULL,
-                               &page_server_connstring,
+                               &pageserver_connstring,
                                "",
                                PGC_SIGHUP,
                                0,    /* no flags required */
@@ -1643,7 +1643,7 @@ pg_init_libpagestore(void)
     if (neon_auth_token)
         neon_log(LOG, "using storage auth token from NEON_AUTH_TOKEN environment variable");
 
-    if (page_server_connstring && page_server_connstring[0])
+    if (pageserver_connstring[0])
     {
         neon_log(PageStoreTrace, "set neon_smgr hook");
         smgr_hook = smgr_neon;
diff --git a/pgxn/neon/pagestore_client.h b/pgxn/neon/pagestore_client.h
index 4470d3a94d..bfe00c9285 100644
--- a/pgxn/neon/pagestore_client.h
+++ b/pgxn/neon/pagestore_client.h
@@ -236,7 +236,7 @@ extern void prefetch_on_ps_disconnect(void);
 
 extern page_server_api *page_server;
 
-extern char *page_server_connstring;
+extern char *pageserver_connstring;
 extern int  flush_every_n_requests;
 extern int  readahead_buffer_size;
 extern char *neon_timeline;
diff --git a/storage_controller/src/compute_hook.rs b/storage_controller/src/compute_hook.rs
index fb03412f3c..efeb6005d5 100644
--- a/storage_controller/src/compute_hook.rs
+++ b/storage_controller/src/compute_hook.rs
@@ -6,13 +6,16 @@ use std::time::Duration;
 
 use anyhow::Context;
 use compute_api::spec::PageserverProtocol;
-use control_plane::endpoint::{ComputeControlPlane, EndpointStatus};
+use compute_api::spec::PageserverShardInfo;
+use control_plane::endpoint::{
+    ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo,
+};
 use control_plane::local_env::LocalEnv;
 use futures::StreamExt;
 use hyper::StatusCode;
 use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT;
 use pageserver_api::controller_api::AvailabilityZone;
-use pageserver_api::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId};
+use pageserver_api::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize, TenantShardId};
 use postgres_connection::parse_host_port;
 use safekeeper_api::membership::SafekeeperGeneration;
 use serde::{Deserialize, Serialize};
@@ -506,27 +509,64 @@ impl ApiMethod for ComputeHookTenant {
             if endpoint.tenant_id == *tenant_id && endpoint.status() == EndpointStatus::Running {
                 tracing::info!("Reconfiguring pageservers for endpoint {endpoint_name}");
 
-                let pageservers = shards
-                    .iter()
-                    .map(|shard| {
-                        let ps_conf = env
-                            .get_pageserver_conf(shard.node_id)
-                            .expect("Unknown pageserver");
-                        if endpoint.grpc {
-                            let addr = ps_conf.listen_grpc_addr.as_ref().expect("no gRPC address");
-                            let (host, port) = parse_host_port(addr).expect("invalid gRPC address");
-                            let port = port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT);
-                            (PageserverProtocol::Grpc, host, port)
-                        } else {
-                            let (host, port) = parse_host_port(&ps_conf.listen_pg_addr)
-                                .expect("Unable to parse listen_pg_addr");
-                            (PageserverProtocol::Libpq, host, port.unwrap_or(5432))
-                        }
-                    })
-                    .collect::<Vec<_>>();
+                let shard_count = match shards.len() {
+                    1 => ShardCount::unsharded(),
+                    n => ShardCount(n.try_into().expect("too many shards")),
+                };
+
+                let mut shard_infos: HashMap<ShardIndex, PageserverShardInfo> = HashMap::new();
+
+                let prefer_protocol = if endpoint.grpc {
+                    PageserverProtocol::Grpc
+                } else {
+                    PageserverProtocol::Libpq
+                };
+
+                for shard in shards.iter() {
+                    let ps_conf = env
+                        .get_pageserver_conf(shard.node_id)
+                        .expect("Unknown pageserver");
+
+                    let libpq_url = Some({
+                        let (host, port) = parse_host_port(&ps_conf.listen_pg_addr)
+                            .expect("Unable to parse listen_pg_addr");
+                        let port = port.unwrap_or(5432);
+                        format!("postgres://no_user@{host}:{port}")
+                    });
+                    let grpc_url = if let Some(grpc_addr) = &ps_conf.listen_grpc_addr {
+                        let (host, port) =
+                            parse_host_port(grpc_addr).expect("invalid gRPC address");
+                        let port = port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT);
+                        Some(format!("grpc://no_user@{host}:{port}"))
+                    } else {
+                        None
+                    };
+                    let pageserver = PageserverShardConnectionInfo {
+                        id: Some(shard.node_id),
+                        libpq_url,
+                        grpc_url,
+                    };
+                    let shard_info = PageserverShardInfo {
+                        pageservers: vec![pageserver],
+                    };
+                    shard_infos.insert(
+                        ShardIndex {
+                            shard_number: shard.shard_number,
+                            shard_count,
+                        },
+                        shard_info,
+                    );
+                }
+
+                let pageserver_conninfo = PageserverConnectionInfo {
+                    shard_count,
+                    stripe_size: stripe_size.map(|val| ShardStripeSize(val.0)),
+                    shards: shard_infos,
+                    prefer_protocol,
+                };
 
                 endpoint
-                    .reconfigure_pageservers(pageservers, *stripe_size)
+                    .reconfigure_pageservers(&pageserver_conninfo)
                    .await
                    .map_err(NotifyError::NeonLocal)?;
             }
diff --git a/test_runner/regress/test_basebackup.py b/test_runner/regress/test_basebackup.py
index d1b10ec85d..23b9105617 100644
--- a/test_runner/regress/test_basebackup.py
+++ b/test_runner/regress/test_basebackup.py
@@ -2,13 +2,15 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
+import pytest
 from fixtures.utils import wait_until
 
 if TYPE_CHECKING:
     from fixtures.neon_fixtures import NeonEnvBuilder
 
 
-def test_basebackup_cache(neon_env_builder: NeonEnvBuilder):
+@pytest.mark.parametrize("grpc", [True, False])
+def test_basebackup_cache(neon_env_builder: NeonEnvBuilder, grpc: bool):
     """
     Simple test for basebackup cache.
     1. Check that we always hit the cache after compute restart.
     """
     env = neon_env_builder.init_start()
 
-    ep = env.endpoints.create("main")
+    ep = env.endpoints.create("main", grpc=grpc)
 
     ps = env.pageserver
     ps_http = ps.http_client()

From b3c1aecd115891753955773257e6fe95efb35242 Mon Sep 17 00:00:00 2001
From: Heikki Linnakangas
Date: Wed, 30 Jul 2025 15:19:00 +0300
Subject: [PATCH 3/6] tests: Stop endpoints in parallel (#12769)

Shaves off a few seconds from tests involving multiple endpoints.
---
 test_runner/fixtures/neon_fixtures.py | 28 +++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py
index 7f59547c73..b5148cbfdc 100644
--- a/test_runner/fixtures/neon_fixtures.py
+++ b/test_runner/fixtures/neon_fixtures.py
@@ -5280,16 +5280,32 @@ class EndpointFactory:
         )
 
     def stop_all(self, fail_on_error=True) -> Self:
-        exception = None
-        for ep in self.endpoints:
+        """
+        Stop all the endpoints in parallel.
+ """ + + # Note: raising an exception from a task in a task group cancels + # all the other tasks. We don't want that, hence the 'stop_one' + # function catches exceptions and puts them on the 'exceptions' + # list for later processing. + exceptions = [] + + async def stop_one(ep): try: - ep.stop() + await asyncio.to_thread(ep.stop) except Exception as e: log.error(f"Failed to stop endpoint {ep.endpoint_id}: {e}") - exception = e + exceptions.append(e) - if fail_on_error and exception is not None: - raise exception + async def async_stop_all(): + async with asyncio.TaskGroup() as tg: + for ep in self.endpoints: + tg.create_task(stop_one(ep)) + + asyncio.run(async_stop_all()) + + if fail_on_error and exceptions: + raise ExceptionGroup("stopping an endpoint failed", exceptions) return self From e989e0da78baf77d16644099a0f648675160d58d Mon Sep 17 00:00:00 2001 From: Ruslan Talpa Date: Wed, 30 Jul 2025 17:17:51 +0300 Subject: [PATCH 4/6] [proxy] accept jwts when configured as rest_broker (#12777) ## Problem when compiled with rest_broker feature and is_rest_broker=true (but is_auth_broker=false) accept_jwts is set to false ## Summary of changes set the config with ``` accept_jwts: args.is_auth_broker || args.is_rest_broker ``` Co-authored-by: Ruslan Talpa --- proxy/src/binary/proxy.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 29b0ad53f2..583cdc95bf 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -700,7 +700,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, + #[cfg(not(feature = "rest_broker"))] accept_jwts: args.is_auth_broker, + #[cfg(feature = "rest_broker")] + accept_jwts: args.is_auth_broker || args.is_rest_broker, console_redirect_confirmation_timeout: args.webauth_confirmation_timeout, }; From 056056bef0d6ec6858d59e4f872bd51091daa93b Mon Sep 17 00:00:00 2001 From: Suhas Thalanki <54014218+thesuhas@users.noreply.github.com> Date: Wed, 30 Jul 2025 10:33:19 -0400 Subject: [PATCH 5/6] fix(compute): validate `prewarm_local_cache()` input (#12648) ## Problem ``` postgres=> select neon.prewarm_local_cache('\xfcfcfcfc01000000ffffffff070000000000000000000000000000000000000000000000000000000000000000000000000000ff', 1); WARNING: terminating connection because of crash of another server process DETAIL: The postmaster has commanded this server process to roll back the current transaction and exit, because another server process exited abnormally and possibly corrupted shared memory. HINT: In a moment you should be able to reconnect to the database and repeat your command. FATAL: server conn crashed? ``` The function takes a bytea argument and casts it to a C struct, without validating the contents. ## Summary of changes Added validation for number of pages to be prefetched and for the chunks as well. 
--- pgxn/neon/file_cache.c | 29 ++++++++++++++++++++++++++--- pgxn/neon/neon_utils.c | 19 +++++++++++++++++++ pgxn/neon/neon_utils.h | 4 ++++ 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 3c680eab86..bb810ada68 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -49,6 +49,7 @@ #include "neon.h" #include "neon_lwlsncache.h" #include "neon_perf_counters.h" +#include "neon_utils.h" #include "pagestore_client.h" #include "communicator.h" @@ -673,8 +674,19 @@ lfc_get_state(size_t max_entries) { if (GET_STATE(entry, j) != UNAVAILABLE) { - BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j); - n_pages += 1; + /* Validate the buffer tag before including it */ + BufferTag test_tag = entry->key; + test_tag.blockNum += j; + + if (BufferTagIsValid(&test_tag)) + { + BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j); + n_pages += 1; + } + else + { + elog(ERROR, "LFC: Skipping invalid buffer tag during cache state capture: blockNum=%u", test_tag.blockNum); + } } } if (++i == n_entries) @@ -683,7 +695,7 @@ lfc_get_state(size_t max_entries) Assert(i == n_entries); fcs->n_pages = n_pages; Assert(pg_popcount((char*)bitmap, ((n_entries << lfc_chunk_size_log) + 7)/8) == n_pages); - elog(LOG, "LFC: save state of %d chunks %d pages", (int)n_entries, (int)n_pages); + elog(LOG, "LFC: save state of %d chunks %d pages (validated)", (int)n_entries, (int)n_pages); } LWLockRelease(lfc_lock); @@ -702,6 +714,7 @@ lfc_prewarm(FileCacheState* fcs, uint32 n_workers) size_t n_entries; size_t prewarm_batch = Min(lfc_prewarm_batch, readahead_buffer_size); size_t fcs_size; + uint32_t max_prefetch_pages; dsm_segment *seg; BackgroundWorkerHandle* bgw_handle[MAX_PREWARM_WORKERS]; @@ -746,6 +759,11 @@ lfc_prewarm(FileCacheState* fcs, uint32 n_workers) n_entries = Min(fcs->n_chunks, lfc_prewarm_limit); Assert(n_entries != 0); + max_prefetch_pages = n_entries << fcs_chunk_size_log; + if (fcs->n_pages > max_prefetch_pages) { + elog(ERROR, "LFC: Number of pages in file cache state (%d) is more than the limit (%d)", fcs->n_pages, max_prefetch_pages); + } + LWLockAcquire(lfc_lock, LW_EXCLUSIVE); /* Do not prewarm more entries than LFC limit */ @@ -898,6 +916,11 @@ lfc_prewarm_main(Datum main_arg) { tag = fcs->chunks[snd_idx >> fcs_chunk_size_log]; tag.blockNum += snd_idx & ((1 << fcs_chunk_size_log) - 1); + + if (!BufferTagIsValid(&tag)) { + elog(ERROR, "LFC: Invalid buffer tag: %u", tag.blockNum); + } + if (!lfc_cache_contains(BufTagGetNRelFileInfo(tag), tag.forkNum, tag.blockNum)) { (void)communicator_prefetch_register_bufferv(tag, NULL, 1, NULL); diff --git a/pgxn/neon/neon_utils.c b/pgxn/neon/neon_utils.c index 1fad44bd58..847d380eb3 100644 --- a/pgxn/neon/neon_utils.c +++ b/pgxn/neon/neon_utils.c @@ -183,3 +183,22 @@ alloc_curl_handle(void) } #endif + +/* + * Check if a BufferTag is valid by verifying all its fields are not invalid. 
+bool
+BufferTagIsValid(const BufferTag *tag)
+{
+#if PG_MAJORVERSION_NUM >= 16
+    return (tag->spcOid != InvalidOid) &&
+        (tag->relNumber != InvalidRelFileNumber) &&
+        (tag->forkNum != InvalidForkNumber) &&
+        (tag->blockNum != InvalidBlockNumber);
+#else
+    return (tag->rnode.spcNode != InvalidOid) &&
+        (tag->rnode.relNode != InvalidOid) &&
+        (tag->forkNum != InvalidForkNumber) &&
+        (tag->blockNum != InvalidBlockNumber);
+#endif
+}
diff --git a/pgxn/neon/neon_utils.h b/pgxn/neon/neon_utils.h
index 7480ac28cc..65d280788d 100644
--- a/pgxn/neon/neon_utils.h
+++ b/pgxn/neon/neon_utils.h
@@ -2,6 +2,7 @@
 #define __NEON_UTILS_H__
 
 #include "lib/stringinfo.h"
+#include "storage/buf_internals.h"
 
 #ifndef WALPROPOSER_LIB
 #include <curl/curl.h>
@@ -16,6 +17,9 @@ void pq_sendint32_le(StringInfo buf, uint32 i);
 void pq_sendint64_le(StringInfo buf, uint64 i);
 void disable_core_dump(void);
 
+/* Buffer tag validation function */
+bool BufferTagIsValid(const BufferTag *tag);
+
 #ifndef WALPROPOSER_LIB
 CURL * alloc_curl_handle(void);

From 842a5091d5db4c23aeb29aea070c37ad06b12d63 Mon Sep 17 00:00:00 2001
From: Suhas Thalanki <54014218+thesuhas@users.noreply.github.com>
Date: Wed, 30 Jul 2025 11:14:59 -0400
Subject: [PATCH 6/6] [BRC-3051] Walproposer: Safekeeper quorum health metrics
 (#930) (#12750)

Today we have no indication (other than spammy logs in PG that nobody
monitors) that the walproposer in PG cannot connect to, or get votes
from, all safekeepers. This means we have no signals that the
safekeepers are operating at degraded redundancy. We need these signals.

Added plumbing in the PG extension so that the `neon_perf_counters` view
exports the following gauge metrics on safekeeper health:

- `num_configured_safekeepers`: The total number of safekeepers
  configured in PG.
- `num_active_safekeepers`: The number of safekeepers that PG is
  actively streaming WAL to.

An alert should be raised whenever `num_active_safekeepers` <
`num_configured_safekeepers`.

The metrics are implemented by adding state to the walproposer shared
memory: a simple array that tracks each safekeeper's active status. A
safekeeper's status is set to active (1) after the walproposer acquires
a quorum and starts streaming data to that safekeeper, and set to
inactive (0) when the connection to the safekeeper is shut down. When
collecting the metrics, we scan the status array in walproposer shared
memory to produce the values for the two gauges.

Added coverage for the metrics to the integration test
`test_wal_acceptor.py::test_timeline_disk_usage_limit`.
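The alert condition above can be evaluated directly against the view. A sketch of a monitoring-style probe, relying only on the two new metrics and the `metric`/`value` columns that `neon_perf_counters` already exposes (the function name is invented):

```
def safekeeper_quorum_degraded(conn) -> bool:
    """Return True if fewer safekeepers are active than configured."""
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT
                max(value) FILTER (WHERE metric = 'num_active_safekeepers'),
                max(value) FILTER (WHERE metric = 'num_configured_safekeepers')
            FROM neon_perf_counters
            """
        )
        active, configured = cur.fetchone()
        # Gauges are absent on builds without this patch; treat as healthy.
        if active is None or configured is None:
            return False
        return active < configured
```

Reading both gauges in a single statement keeps the comparison consistent with one snapshot of the view, mirroring how the C side takes the spinlock once to read both counts.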
## Problem

## Summary of changes

---------

Co-authored-by: William Huang
---
 libs/walproposer/src/api_bindings.rs     | 34 ++++++++++++++++++++++++
 libs/walproposer/src/walproposer.rs      | 15 +++++++++++
 pgxn/neon/neon_perf_counters.c           | 27 +++++++++++++++++++
 pgxn/neon/walproposer.c                  | 16 ++++++++++-
 pgxn/neon/walproposer.h                  | 26 ++++++++++++++++++
 pgxn/neon/walproposer_pg.c               | 23 ++++++++++++++++
 test_runner/regress/test_wal_acceptor.py | 33 ++++++++++++++++++++---
 7 files changed, 170 insertions(+), 4 deletions(-)

diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs
index 9f88ea6b11..9c90beb379 100644
--- a/libs/walproposer/src/api_bindings.rs
+++ b/libs/walproposer/src/api_bindings.rs
@@ -341,6 +341,34 @@ extern "C-unwind" fn log_internal(
     }
 }
 
+/* BEGIN_HADRON */
+extern "C" fn reset_safekeeper_statuses_for_metrics(wp: *mut WalProposer, num_safekeepers: u32) {
+    unsafe {
+        let callback_data = (*(*wp).config).callback_data;
+        let api = callback_data as *mut Box<dyn ApiImpl>;
+        if api.is_null() {
+            return;
+        }
+        (*api).reset_safekeeper_statuses_for_metrics(&mut (*wp), num_safekeepers);
+    }
+}
+
+extern "C" fn update_safekeeper_status_for_metrics(
+    wp: *mut WalProposer,
+    sk_index: u32,
+    status: u8,
+) {
+    unsafe {
+        let callback_data = (*(*wp).config).callback_data;
+        let api = callback_data as *mut Box<dyn ApiImpl>;
+        if api.is_null() {
+            return;
+        }
+        (*api).update_safekeeper_status_for_metrics(&mut (*wp), sk_index, status);
+    }
+}
+/* END_HADRON */
+
 #[derive(Debug, PartialEq)]
 pub enum Level {
     Debug5,
@@ -414,6 +442,10 @@ pub(crate) fn create_api() -> walproposer_api {
         finish_sync_safekeepers: Some(finish_sync_safekeepers),
         process_safekeeper_feedback: Some(process_safekeeper_feedback),
         log_internal: Some(log_internal),
+        /* BEGIN_HADRON */
+        reset_safekeeper_statuses_for_metrics: Some(reset_safekeeper_statuses_for_metrics),
+        update_safekeeper_status_for_metrics: Some(update_safekeeper_status_for_metrics),
+        /* END_HADRON */
     }
 }
 
@@ -451,6 +483,8 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState {
         replica_promote: false,
         min_ps_feedback: empty_feedback,
         wal_rate_limiter: empty_wal_rate_limiter,
+        num_safekeepers: 0,
+        safekeeper_status: [0; 32],
     }
 }
 
diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs
index 93bb0d5eb0..8453279c5c 100644
--- a/libs/walproposer/src/walproposer.rs
+++ b/libs/walproposer/src/walproposer.rs
@@ -159,6 +159,21 @@ pub trait ApiImpl {
     fn after_election(&self, _wp: &mut WalProposer) {
         todo!()
     }
+
+    /* BEGIN_HADRON */
+    fn reset_safekeeper_statuses_for_metrics(&self, _wp: &mut WalProposer, _num_safekeepers: u32) {
+        // Do nothing for testing purposes.
+    }
+
+    fn update_safekeeper_status_for_metrics(
+        &self,
+        _wp: &mut WalProposer,
+        _sk_index: u32,
+        _status: u8,
+    ) {
+        // Do nothing for testing purposes.
+    }
+    /* END_HADRON */
 }
 
 #[derive(Debug)]
diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c
index fada4cba1e..4527084514 100644
--- a/pgxn/neon/neon_perf_counters.c
+++ b/pgxn/neon/neon_perf_counters.c
@@ -391,6 +391,12 @@ neon_get_perf_counters(PG_FUNCTION_ARGS)
     neon_per_backend_counters totals = {0};
     metric_t *metrics;
 
+    /* BEGIN_HADRON */
+    WalproposerShmemState *wp_shmem;
+    uint32 num_safekeepers;
+    uint32 num_active_safekeepers;
+    /* END_HADRON */
+
     /* We put all the tuples into a tuplestore in one go.
      */
     InitMaterializedSRF(fcinfo, 0);
 
@@ -437,11 +443,32 @@ neon_get_perf_counters(PG_FUNCTION_ARGS)
     // Not ideal but piggyback our databricks counters into the neon perf counters view
     // so that we don't need to introduce neon--1.x+1.sql to add a new view.
     {
+        // Keeping this code in its own block to work around the C90 "don't mix
+        // declarations and code" rule when we define the `databricks_metrics`
+        // array in the next block. Yes, we are seriously dealing with C90 rules
+        // in 2025.
+
+        // Read safekeeper status from walproposer shared memory first.
+        // Note that we are taking a mutex when reading from walproposer shared
+        // memory so that the total safekeeper count is consistent with the
+        // active wal acceptors count. Assuming that we don't query this view
+        // too often, the mutex should not be a huge deal.
+        wp_shmem = GetWalpropShmemState();
+        SpinLockAcquire(&wp_shmem->mutex);
+        num_safekeepers = wp_shmem->num_safekeepers;
+        num_active_safekeepers = 0;
+        for (int i = 0; i < num_safekeepers; i++) {
+            if (wp_shmem->safekeeper_status[i] == 1) {
+                num_active_safekeepers++;
+            }
+        }
+        SpinLockRelease(&wp_shmem->mutex);
+    }
+    {
         metric_t databricks_metrics[] = {
             {"sql_index_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->index_corruption_count)},
             {"sql_data_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->data_corruption_count)},
             {"sql_internal_error_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->internal_error_count)},
             {"ps_corruption_detected", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->ps_corruption_detected)},
+            {"num_active_safekeepers", false, 0.0, (double) num_active_safekeepers},
+            {"num_configured_safekeepers", false, 0.0, (double) num_safekeepers},
             {NULL, false, 0, 0},
         };
         for (int i = 0; databricks_metrics[i].name != NULL; i++)
diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c
index c85a6f4b6f..dd42eaf18e 100644
--- a/pgxn/neon/walproposer.c
+++ b/pgxn/neon/walproposer.c
@@ -154,7 +154,9 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
         wp->safekeeper[wp->n_safekeepers].state = SS_OFFLINE;
         wp->safekeeper[wp->n_safekeepers].active_state = SS_ACTIVE_SEND;
         wp->safekeeper[wp->n_safekeepers].wp = wp;
-
+        /* BEGIN_HADRON */
+        wp->safekeeper[wp->n_safekeepers].index = wp->n_safekeepers;
+        /* END_HADRON */
         {
             Safekeeper *sk = &wp->safekeeper[wp->n_safekeepers];
             int written = 0;
@@ -183,6 +185,10 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
     if (wp->safekeepers_generation > INVALID_GENERATION && wp->config->proto_version < 3)
         wp_log(FATAL, "enabling generations requires protocol version 3");
     wp_log(LOG, "using safekeeper protocol version %d", wp->config->proto_version);
+
+    /* BEGIN_HADRON */
+    wp->api.reset_safekeeper_statuses_for_metrics(wp, wp->n_safekeepers);
+    /* END_HADRON */
 
     /* Fill the greeting package */
     wp->greetRequest.pam.tag = 'g';
@@ -355,6 +361,10 @@ ShutdownConnection(Safekeeper *sk)
     sk->state = SS_OFFLINE;
     sk->streamingAt = InvalidXLogRecPtr;
 
+    /* BEGIN_HADRON */
+    sk->wp->api.update_safekeeper_status_for_metrics(sk->wp, sk->index, 0);
+    /* END_HADRON */
+
     MembershipConfigurationFree(&sk->greetResponse.mconf);
     if (sk->voteResponse.termHistory.entries)
         pfree(sk->voteResponse.termHistory.entries);
@@ -1530,6 +1540,10 @@ StartStreaming(Safekeeper *sk)
     sk->active_state = SS_ACTIVE_SEND;
     sk->streamingAt = sk->startStreamingAt;
 
+    /* BEGIN_HADRON */
+    sk->wp->api.update_safekeeper_status_for_metrics(sk->wp, sk->index, 1);
+    /* END_HADRON */
+
     /*
      * Donors can only be in SS_ACTIVE state, so we potentially update the
      * donor when we switch one to SS_ACTIVE.
diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h
index d6cd532bec..ac42c2925d 100644
--- a/pgxn/neon/walproposer.h
+++ b/pgxn/neon/walproposer.h
@@ -432,6 +432,10 @@ typedef struct WalproposerShmemState
     /* BEGIN_HADRON */
     /* The WAL rate limiter */
     WalRateLimiter wal_rate_limiter;
+    /* Number of safekeepers in the config */
+    uint32 num_safekeepers;
+    /* Per-safekeeper status flags: 0=inactive, 1=active */
+    uint8 safekeeper_status[MAX_SAFEKEEPERS];
     /* END_HADRON */
 } WalproposerShmemState;
 
@@ -483,6 +487,11 @@ typedef struct Safekeeper
     char const *host;
     char const *port;
 
+    /* BEGIN_HADRON */
+    /* index of this safekeeper in the WalProposer array */
+    uint32 index;
+    /* END_HADRON */
+
     /*
      * connection string for connecting/reconnecting.
      *
@@ -731,6 +740,23 @@ typedef struct walproposer_api
      * handled by elog().
      */
     void (*log_internal) (WalProposer *wp, int level, const char *line);
+
+    /*
+     * BEGIN_HADRON
+     * APIs manipulating shared memory state used for Safekeeper quorum
+     * health metrics.
+     */
+
+    /*
+     * Reset the safekeeper statuses in shared memory for metric purposes.
+     */
+    void (*reset_safekeeper_statuses_for_metrics) (WalProposer *wp, uint32 num_safekeepers);
+
+    /*
+     * Update the safekeeper status in shared memory for metric purposes.
+     */
+    void (*update_safekeeper_status_for_metrics) (WalProposer *wp, uint32 sk_index, uint8 status);
+
+    /* END_HADRON */
 } walproposer_api;
 
 /*
diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c
index da86c5d498..47b5ec523f 100644
--- a/pgxn/neon/walproposer_pg.c
+++ b/pgxn/neon/walproposer_pg.c
@@ -2261,6 +2261,27 @@ GetNeonCurrentClusterSize(void)
 }
 
 uint64 GetNeonCurrentClusterSize(void);
+/* BEGIN_HADRON */
+static void
+walprop_pg_reset_safekeeper_statuses_for_metrics(WalProposer *wp, uint32 num_safekeepers)
+{
+    WalproposerShmemState* shmem = wp->api.get_shmem_state(wp);
+    SpinLockAcquire(&shmem->mutex);
+    shmem->num_safekeepers = num_safekeepers;
+    memset(shmem->safekeeper_status, 0, sizeof(shmem->safekeeper_status));
+    SpinLockRelease(&shmem->mutex);
+}
+
+static void
+walprop_pg_update_safekeeper_status_for_metrics(WalProposer *wp, uint32 sk_index, uint8 status)
+{
+    WalproposerShmemState* shmem = wp->api.get_shmem_state(wp);
+    Assert(sk_index < MAX_SAFEKEEPERS);
+    SpinLockAcquire(&shmem->mutex);
+    shmem->safekeeper_status[sk_index] = status;
+    SpinLockRelease(&shmem->mutex);
+}
+/* END_HADRON */
 
 static const walproposer_api walprop_pg = {
     .get_shmem_state = walprop_pg_get_shmem_state,
@@ -2294,4 +2315,6 @@ static const walproposer_api walprop_pg = {
     .finish_sync_safekeepers = walprop_pg_finish_sync_safekeepers,
     .process_safekeeper_feedback = walprop_pg_process_safekeeper_feedback,
     .log_internal = walprop_pg_log_internal,
+    .reset_safekeeper_statuses_for_metrics = walprop_pg_reset_safekeeper_statuses_for_metrics,
+    .update_safekeeper_status_for_metrics = walprop_pg_update_safekeeper_status_for_metrics,
 };
diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py
index c691087259..33d308fb5a 100644
--- a/test_runner/regress/test_wal_acceptor.py
+++ b/test_runner/regress/test_wal_acceptor.py
@@ -2742,6 +2742,7 @@ def test_pull_timeline_partial_segment_integrity(neon_env_builder: NeonEnvBuilde
     wait_until(unevicted)
 
 
+@pytest.mark.skip(reason="Lakebase mode")
 def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
     """
     Test that the timeline disk usage circuit breaker works as expected. We test that:
@@ -2757,18 +2758,32 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
     remote_storage_kind = s3_storage()
     neon_env_builder.enable_safekeeper_remote_storage(remote_storage_kind)
 
-    # Set a very small disk usage limit (1KB)
-    neon_env_builder.safekeeper_extra_opts = ["--max-timeline-disk-usage-bytes=1024"]
-
     env = neon_env_builder.init_start()
 
     # Create a timeline and endpoint
     env.create_branch("test_timeline_disk_usage_limit")
     endpoint = env.endpoints.create_start("test_timeline_disk_usage_limit")
 
+    # Install the neon extension in the test database. We need it to query perf counter metrics.
+    with closing(endpoint.connect()) as conn:
+        with conn.cursor() as cur:
+            cur.execute("CREATE EXTENSION IF NOT EXISTS neon")
+            # Sanity-check safekeeper connection status in neon_perf_counters in the happy case.
+            cur.execute(
+                "SELECT value FROM neon_perf_counters WHERE metric = 'num_active_safekeepers'"
+            )
+            assert cur.fetchone() == (1,), "Expected 1 active safekeeper"
+            cur.execute(
+                "SELECT value FROM neon_perf_counters WHERE metric = 'num_configured_safekeepers'"
+            )
+            assert cur.fetchone() == (1,), "Expected 1 configured safekeeper"
+
     # Get the safekeeper
     sk = env.safekeepers[0]
 
+    # Restart the safekeeper with a very small disk usage limit (1KB)
+    sk.stop().start(["--max-timeline-disk-usage-bytes=1024"])
+
     # Inject a failpoint to stop WAL backup
     with sk.http_client() as http_cli:
         http_cli.configure_failpoints([("backup-lsn-range-pausable", "pause")])
@@ -2794,6 +2809,18 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
     wait_until(error_logged)
     log.info("Found expected error message in compute log, resuming.")
 
+    with closing(endpoint.connect()) as conn:
+        with conn.cursor() as cur:
+            # Confirm that neon_perf_counters also indicates that there are no active safekeepers
+            cur.execute(
+                "SELECT value FROM neon_perf_counters WHERE metric = 'num_active_safekeepers'"
+            )
+            assert cur.fetchone() == (0,), "Expected 0 active safekeepers"
+            cur.execute(
+                "SELECT value FROM neon_perf_counters WHERE metric = 'num_configured_safekeepers'"
+            )
+            assert cur.fetchone() == (1,), "Expected 1 configured safekeeper"
+
     # Sanity check that the hanging insert is indeed still hanging. If it isn't,
     # the circuit breaker we implemented didn't work as expected.
     time.sleep(2)
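A natural follow-up, not included in this patch, is to check that the gauge recovers once the failpoint is lifted and WAL backup resumes. A sketch reusing the `wait_until` and `closing` helpers this test file already imports (the helper name is invented):

```
def wait_for_active_safekeepers(endpoint, expected: int = 1):
    # Poll neon_perf_counters until the active-safekeeper gauge reaches
    # the expected value; wait_until retries the assertion until timeout.
    def check():
        with closing(endpoint.connect()) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT value FROM neon_perf_counters WHERE metric = 'num_active_safekeepers'"
                )
                assert cur.fetchone() == (expected,)

    wait_until(check)
```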