diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 27d33d8cd8..1033670e2b 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -7,7 +7,7 @@ use compute_api::responses::{ }; use compute_api::spec::{ ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, GenericOption, - PageserverProtocol, PgIdent, Role, + PageserverConnectionInfo, PageserverProtocol, PgIdent, Role, }; use futures::StreamExt; use futures::future::join_all; @@ -38,7 +38,7 @@ use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; use utils::pid_file; -use utils::shard::{ShardCount, ShardIndex, ShardNumber}; +use utils::shard::{ShardIndex, ShardNumber, ShardStripeSize}; use crate::configurator::launch_configurator; use crate::disk_quota::set_disk_quota; @@ -249,7 +249,7 @@ pub struct ParsedSpec { pub spec: ComputeSpec, pub tenant_id: TenantId, pub timeline_id: TimelineId, - pub pageserver_connstr: String, + pub pageserver_conninfo: PageserverConnectionInfo, pub safekeeper_connstrings: Vec, pub storage_auth_token: Option, /// k8s dns name and port @@ -297,25 +297,47 @@ impl ParsedSpec { } impl TryFrom for ParsedSpec { - type Error = String; - fn try_from(spec: ComputeSpec) -> Result { + type Error = anyhow::Error; + fn try_from(spec: ComputeSpec) -> Result { // Extract the options from the spec file that are needed to connect to // the storage system. // - // For backwards-compatibility, the top-level fields in the spec file - // may be empty. In that case, we need to dig them from the GUCs in the - // cluster.settings field. - let pageserver_connstr = spec - .pageserver_connstring - .clone() - .or_else(|| spec.cluster.settings.find("neon.pageserver_connstring")) - .ok_or("pageserver connstr should be provided")?; + // In compute specs generated by old control plane versions, the spec file might + // be missing the `pageserver_connection_info` field. In that case, we need to dig + // the pageserver connection info from the `pageserver_connstr` field instead, or + // if that's missing too, from the GUC in the cluster.settings field. + let mut pageserver_conninfo = spec.pageserver_connection_info.clone(); + if pageserver_conninfo.is_none() { + if let Some(pageserver_connstr_field) = &spec.pageserver_connstring { + pageserver_conninfo = Some(PageserverConnectionInfo::from_connstr( + pageserver_connstr_field, + spec.shard_stripe_size, + )?); + } + } + if pageserver_conninfo.is_none() { + if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") { + let stripe_size = if let Some(guc) = spec.cluster.settings.find("neon.stripe_size") + { + Some(ShardStripeSize(u32::from_str(&guc)?)) + } else { + None + }; + pageserver_conninfo = + Some(PageserverConnectionInfo::from_connstr(&guc, stripe_size)?); + } + } + let pageserver_conninfo = pageserver_conninfo.ok_or(anyhow::anyhow!( + "pageserver connection information should be provided" + ))?; + + // Similarly for safekeeper connection strings let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() { if matches!(spec.mode, ComputeMode::Primary) { spec.cluster .settings .find("neon.safekeepers") - .ok_or("safekeeper connstrings should be provided")? + .ok_or(anyhow::anyhow!("safekeeper connstrings should be provided"))? .split(',') .map(|str| str.to_string()) .collect() @@ -330,22 +352,22 @@ impl TryFrom for ParsedSpec { let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id { tenant_id } else { - spec.cluster + let guc = spec + .cluster .settings .find("neon.tenant_id") - .ok_or("tenant id should be provided") - .map(|s| TenantId::from_str(&s))? - .or(Err("invalid tenant id"))? + .ok_or(anyhow::anyhow!("tenant id should be provided"))?; + TenantId::from_str(&guc).context("invalid tenant id")? }; let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id { timeline_id } else { - spec.cluster + let guc = spec + .cluster .settings .find("neon.timeline_id") - .ok_or("timeline id should be provided") - .map(|s| TimelineId::from_str(&s))? - .or(Err("invalid timeline id"))? + .ok_or(anyhow::anyhow!("timeline id should be provided"))?; + TimelineId::from_str(&guc).context(anyhow::anyhow!("invalid timeline id"))? }; let endpoint_storage_addr: Option = spec @@ -359,7 +381,7 @@ impl TryFrom for ParsedSpec { let res = ParsedSpec { spec, - pageserver_connstr, + pageserver_conninfo, safekeeper_connstrings, storage_auth_token, tenant_id, @@ -369,7 +391,7 @@ impl TryFrom for ParsedSpec { }; // Now check validity of the parsed specification - res.validate()?; + res.validate().map_err(anyhow::Error::msg)?; Ok(res) } } @@ -1195,12 +1217,10 @@ impl ComputeNode { fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> { let spec = compute_state.pspec.as_ref().expect("spec must be set"); - let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap(); let started = Instant::now(); - - let (connected, size) = match PageserverProtocol::from_connstring(shard0_connstr)? { - PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?, + let (connected, size) = match spec.pageserver_conninfo.prefer_protocol { PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?, + PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?, }; self.fix_zenith_signal_neon_signal()?; @@ -1238,23 +1258,20 @@ impl ComputeNode { /// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when /// the connection was established, and the (compressed) size of the basebackup. fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> { - let shard0_connstr = spec - .pageserver_connstr - .split(',') - .next() - .unwrap() - .to_string(); - let shard_index = match spec.pageserver_connstr.split(',').count() as u8 { - 0 | 1 => ShardIndex::unsharded(), - count => ShardIndex::new(ShardNumber(0), ShardCount(count)), + let shard0_index = ShardIndex { + shard_number: ShardNumber(0), + shard_count: spec.pageserver_conninfo.shard_count, }; - + let shard0_url = spec + .pageserver_conninfo + .shard_url(ShardNumber(0), PageserverProtocol::Grpc)? + .to_owned(); let (reader, connected) = tokio::runtime::Handle::current().block_on(async move { let mut client = page_api::Client::connect( - shard0_connstr, + shard0_url, spec.tenant_id, spec.timeline_id, - shard_index, + shard0_index, spec.storage_auth_token.clone(), None, // NB: base backups use payload compression ) @@ -1286,7 +1303,9 @@ impl ComputeNode { /// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp /// when the connection was established, and the (compressed) size of the basebackup. fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> { - let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap(); + let shard0_connstr = spec + .pageserver_conninfo + .shard_url(ShardNumber(0), PageserverProtocol::Libpq)?; let mut config = postgres::Config::from_str(shard0_connstr)?; // Use the storage auth token from the config file, if given. @@ -1373,10 +1392,7 @@ impl ComputeNode { return result; } Err(ref e) if attempts < max_attempts => { - warn!( - "Failed to get basebackup: {} (attempt {}/{})", - e, attempts, max_attempts - ); + warn!("Failed to get basebackup: {e:?} (attempt {attempts}/{max_attempts})"); std::thread::sleep(std::time::Duration::from_millis(retry_period_ms as u64)); retry_period_ms *= 1.5; } @@ -1589,16 +1605,8 @@ impl ComputeNode { } }; - info!( - "getting basebackup@{} from pageserver {}", - lsn, &pspec.pageserver_connstr - ); - self.get_basebackup(compute_state, lsn).with_context(|| { - format!( - "failed to get basebackup@{} from pageserver {}", - lsn, &pspec.pageserver_connstr - ) - })?; + self.get_basebackup(compute_state, lsn) + .with_context(|| format!("failed to get basebackup@{lsn}"))?; if let Some(settings) = databricks_settings { copy_tls_certificates( @@ -2642,22 +2650,22 @@ LIMIT 100", /// The operation will time out after a specified duration. pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) { let state = self.state.lock().unwrap(); - let old_pageserver_connstr = state + let old_pageserver_conninfo = state .pspec .as_ref() .expect("spec must be set") - .pageserver_connstr + .pageserver_conninfo .clone(); let mut unchanged = true; let _ = self .state_changed .wait_timeout_while(state, duration, |s| { - let pageserver_connstr = &s + let pageserver_conninfo = &s .pspec .as_ref() .expect("spec must be set") - .pageserver_connstr; - unchanged = pageserver_connstr == &old_pageserver_connstr; + .pageserver_conninfo; + unchanged = pageserver_conninfo == &old_pageserver_conninfo; unchanged }) .unwrap(); @@ -2915,7 +2923,10 @@ mod tests { match ParsedSpec::try_from(spec.clone()) { Ok(_p) => panic!("Failed to detect duplicate entry"), - Err(e) => assert!(e.starts_with("duplicate entry in safekeeper_connstrings:")), + Err(e) => assert!( + e.to_string() + .starts_with("duplicate entry in safekeeper_connstrings:") + ), }; } } diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index 55a1eda0b7..e7dde5c5f5 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -18,6 +18,8 @@ use crate::pg_helpers::{ }; use crate::tls::{self, SERVER_CRT, SERVER_KEY}; +use utils::shard::{ShardIndex, ShardNumber}; + /// Check that `line` is inside a text file and put it there if it is not. /// Create file if it doesn't exist. pub fn line_in_file(path: &Path, line: &str) -> Result { @@ -69,9 +71,75 @@ pub fn write_postgres_conf( } // Add options for connecting to storage writeln!(file, "# Neon storage settings")?; - if let Some(s) = &spec.pageserver_connstring { - writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?; + writeln!(file)?; + if let Some(conninfo) = &spec.pageserver_connection_info { + let mut libpq_urls: Option> = Some(Vec::new()); + let num_shards = if conninfo.shard_count.0 == 0 { + 1 // unsharded, treat it as a single shard + } else { + conninfo.shard_count.0 + }; + + for shard_number in 0..num_shards { + let shard_index = ShardIndex { + shard_number: ShardNumber(shard_number), + shard_count: conninfo.shard_count, + }; + let info = conninfo.shards.get(&shard_index).ok_or_else(|| { + anyhow::anyhow!( + "shard {shard_index} missing from pageserver_connection_info shard map" + ) + })?; + + let first_pageserver = info + .pageservers + .first() + .expect("must have at least one pageserver"); + + // Add the libpq URL to the array, or if the URL is missing, reset the array + // forgetting any previous entries. All servers must have a libpq URL, or none + // at all. + if let Some(url) = &first_pageserver.libpq_url { + if let Some(ref mut urls) = libpq_urls { + urls.push(url.clone()); + } + } else { + libpq_urls = None + } + } + if let Some(libpq_urls) = libpq_urls { + writeln!( + file, + "# derived from compute spec's pageserver_conninfo field" + )?; + writeln!( + file, + "neon.pageserver_connstring={}", + escape_conf_value(&libpq_urls.join(",")) + )?; + } else { + writeln!(file, "# no neon.pageserver_connstring")?; + } + + if let Some(stripe_size) = conninfo.stripe_size { + writeln!( + file, + "# from compute spec's pageserver_conninfo.stripe_size field" + )?; + writeln!(file, "neon.stripe_size={stripe_size}")?; + } + } else { + if let Some(s) = &spec.pageserver_connstring { + writeln!(file, "# from compute spec's pageserver_connstring field")?; + writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?; + } + + if let Some(stripe_size) = spec.shard_stripe_size { + writeln!(file, "# from compute spec's shard_stripe_size field")?; + writeln!(file, "neon.stripe_size={stripe_size}")?; + } } + if !spec.safekeeper_connstrings.is_empty() { let mut neon_safekeepers_value = String::new(); tracing::info!( diff --git a/compute_tools/src/configurator.rs b/compute_tools/src/configurator.rs index feca8337b2..79eb80c4a0 100644 --- a/compute_tools/src/configurator.rs +++ b/compute_tools/src/configurator.rs @@ -122,8 +122,11 @@ fn configurator_main_loop(compute: &Arc) { // into the type system. assert_eq!(state.status, ComputeStatus::RefreshConfiguration); - if state.pspec.as_ref().map(|ps| ps.pageserver_connstr.clone()) - == Some(pspec.pageserver_connstr.clone()) + if state + .pspec + .as_ref() + .map(|ps| ps.pageserver_conninfo.clone()) + == Some(pspec.pageserver_conninfo.clone()) { info!( "Refresh configuration: Retrieved spec is the same as the current spec. Waiting for control plane to update the spec before attempting reconfiguration." diff --git a/compute_tools/src/lsn_lease.rs b/compute_tools/src/lsn_lease.rs index bb0828429d..6abfea82e0 100644 --- a/compute_tools/src/lsn_lease.rs +++ b/compute_tools/src/lsn_lease.rs @@ -4,14 +4,13 @@ use std::thread; use std::time::{Duration, SystemTime}; use anyhow::{Result, bail}; -use compute_api::spec::{ComputeMode, PageserverProtocol}; -use itertools::Itertools as _; +use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverProtocol}; use pageserver_page_api as page_api; use postgres::{NoTls, SimpleQueryMessage}; use tracing::{info, warn}; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; -use utils::shard::{ShardCount, ShardNumber, TenantShardId}; +use utils::shard::TenantShardId; use crate::compute::ComputeNode; @@ -78,17 +77,16 @@ fn acquire_lsn_lease_with_retry( loop { // Note: List of pageservers is dynamic, need to re-read configs before each attempt. - let (connstrings, auth) = { + let (conninfo, auth) = { let state = compute.state.lock().unwrap(); let spec = state.pspec.as_ref().expect("spec must be set"); ( - spec.pageserver_connstr.clone(), + spec.pageserver_conninfo.clone(), spec.storage_auth_token.clone(), ) }; - let result = - try_acquire_lsn_lease(&connstrings, auth.as_deref(), tenant_id, timeline_id, lsn); + let result = try_acquire_lsn_lease(conninfo, auth.as_deref(), tenant_id, timeline_id, lsn); match result { Ok(Some(res)) => { return Ok(res); @@ -112,35 +110,44 @@ fn acquire_lsn_lease_with_retry( /// Tries to acquire LSN leases on all Pageserver shards. fn try_acquire_lsn_lease( - connstrings: &str, + conninfo: PageserverConnectionInfo, auth: Option<&str>, tenant_id: TenantId, timeline_id: TimelineId, lsn: Lsn, ) -> Result> { - let connstrings = connstrings.split(',').collect_vec(); - let shard_count = connstrings.len(); let mut leases = Vec::new(); - for (shard_number, &connstring) in connstrings.iter().enumerate() { - let tenant_shard_id = match shard_count { - 0 | 1 => TenantShardId::unsharded(tenant_id), - shard_count => TenantShardId { - tenant_id, - shard_number: ShardNumber(shard_number as u8), - shard_count: ShardCount::new(shard_count as u8), - }, + for (shard_index, shard) in conninfo.shards.into_iter() { + let tenant_shard_id = TenantShardId { + tenant_id, + shard_number: shard_index.shard_number, + shard_count: shard_index.shard_count, }; - let lease = match PageserverProtocol::from_connstring(connstring)? { - PageserverProtocol::Libpq => { - acquire_lsn_lease_libpq(connstring, auth, tenant_shard_id, timeline_id, lsn)? - } - PageserverProtocol::Grpc => { - acquire_lsn_lease_grpc(connstring, auth, tenant_shard_id, timeline_id, lsn)? - } - }; - leases.push(lease); + // XXX: If there are more than pageserver for the one shard, do we need to get a + // leas on all of them? Currently, that's what we assume, but this is hypothetical + // as of this writing, as we never pass the info for more than one pageserver per + // shard. + for pageserver in shard.pageservers { + let lease = match conninfo.prefer_protocol { + PageserverProtocol::Grpc => acquire_lsn_lease_grpc( + &pageserver.grpc_url.unwrap(), + auth, + tenant_shard_id, + timeline_id, + lsn, + )?, + PageserverProtocol::Libpq => acquire_lsn_lease_libpq( + &pageserver.libpq_url.unwrap(), + auth, + tenant_shard_id, + timeline_id, + lsn, + )?, + }; + leases.push(lease); + } } Ok(leases.into_iter().min().flatten()) diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index c126835066..95500b0b18 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -19,6 +19,9 @@ use compute_api::requests::ComputeClaimsScope; use compute_api::spec::{ComputeMode, PageserverProtocol}; use control_plane::broker::StorageBroker; use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode}; +use control_plane::endpoint::{ + local_pageserver_conf_to_conn_info, tenant_locate_response_to_conn_info, +}; use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage}; use control_plane::local_env; use control_plane::local_env::{ @@ -44,7 +47,6 @@ use pageserver_api::models::{ }; use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId}; use postgres_backend::AuthType; -use postgres_connection::parse_host_port; use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId}; use safekeeper_api::{ DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT, @@ -52,7 +54,6 @@ use safekeeper_api::{ }; use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR; use tokio::task::JoinSet; -use url::Host; use utils::auth::{Claims, Scope}; use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId}; use utils::lsn::Lsn; @@ -1546,62 +1547,41 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res )?; } - let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id { - let conf = env.get_pageserver_conf(pageserver_id).unwrap(); - // Use gRPC if requested. - let pageserver = if endpoint.grpc { - let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config"); - let (host, port) = parse_host_port(grpc_addr)?; - let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); - (PageserverProtocol::Grpc, host, port) - } else { - let (host, port) = parse_host_port(&conf.listen_pg_addr)?; - let port = port.unwrap_or(5432); - (PageserverProtocol::Libpq, host, port) - }; - // If caller is telling us what pageserver to use, this is not a tenant which is - // fully managed by storage controller, therefore not sharded. - (vec![pageserver], DEFAULT_STRIPE_SIZE) + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc + } else { + PageserverProtocol::Libpq + }; + + let mut pageserver_conninfo = if let Some(ps_id) = pageserver_id { + let conf = env.get_pageserver_conf(ps_id).unwrap(); + local_pageserver_conf_to_conn_info(conf)? } else { // Look up the currently attached location of the tenant, and its striping metadata, // to pass these on to postgres. let storage_controller = StorageController::from_env(env); let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?; - let pageservers = futures::future::try_join_all( - locate_result.shards.into_iter().map(|shard| async move { - if let ComputeMode::Static(lsn) = endpoint.mode { - // Initialize LSN leases for static computes. + assert!(!locate_result.shards.is_empty()); + + // Initialize LSN leases for static computes. + if let ComputeMode::Static(lsn) = endpoint.mode { + futures::future::try_join_all(locate_result.shards.iter().map( + |shard| async move { let conf = env.get_pageserver_conf(shard.node_id).unwrap(); let pageserver = PageServerNode::from_env(env, conf); pageserver .http_client .timeline_init_lsn_lease(shard.shard_id, endpoint.timeline_id, lsn) - .await?; - } + .await + }, + )) + .await?; + } - let pageserver = if endpoint.grpc { - ( - PageserverProtocol::Grpc, - Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))?, - shard.listen_grpc_port.expect("no gRPC port"), - ) - } else { - ( - PageserverProtocol::Libpq, - Host::parse(&shard.listen_pg_addr)?, - shard.listen_pg_port, - ) - }; - anyhow::Ok(pageserver) - }), - ) - .await?; - let stripe_size = locate_result.shard_params.stripe_size; - - (pageservers, stripe_size) + tenant_locate_response_to_conn_info(&locate_result)? }; - assert!(!pageservers.is_empty()); + pageserver_conninfo.prefer_protocol = prefer_protocol; let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?; let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) { @@ -1631,9 +1611,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res endpoint_storage_addr, safekeepers_generation, safekeepers, - pageservers, + pageserver_conninfo, remote_ext_base_url: remote_ext_base_url.clone(), - shard_stripe_size: stripe_size.0 as usize, create_test_user: args.create_test_user, start_timeout: args.start_timeout, autoprewarm: args.autoprewarm, @@ -1650,37 +1629,29 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .endpoints .get(endpoint_id.as_str()) .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; - let pageservers = match args.pageserver_id { + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc + } else { + PageserverProtocol::Libpq + }; + let mut pageserver_conninfo = match args.pageserver_id { Some(pageserver_id) => { - let pageserver = - PageServerNode::from_env(env, env.get_pageserver_conf(pageserver_id)?); - - vec![( - PageserverProtocol::Libpq, - pageserver.pg_connection_config.host().clone(), - pageserver.pg_connection_config.port(), - )] + let conf = env.get_pageserver_conf(pageserver_id)?; + local_pageserver_conf_to_conn_info(conf)? } None => { let storage_controller = StorageController::from_env(env); - storage_controller - .tenant_locate(endpoint.tenant_id) - .await? - .shards - .into_iter() - .map(|shard| { - ( - PageserverProtocol::Libpq, - Host::parse(&shard.listen_pg_addr) - .expect("Storage controller reported malformed host"), - shard.listen_pg_port, - ) - }) - .collect::>() + let locate_result = + storage_controller.tenant_locate(endpoint.tenant_id).await?; + + tenant_locate_response_to_conn_info(&locate_result)? } }; + pageserver_conninfo.prefer_protocol = prefer_protocol; - endpoint.update_pageservers_in_config(pageservers).await?; + endpoint + .update_pageservers_in_config(&pageserver_conninfo) + .await?; } EndpointCmd::Reconfigure(args) => { let endpoint_id = &args.endpoint_id; @@ -1688,51 +1659,30 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .endpoints .get(endpoint_id.as_str()) .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; - let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id { - let conf = env.get_pageserver_conf(ps_id)?; - // Use gRPC if requested. - let pageserver = if endpoint.grpc { - let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config"); - let (host, port) = parse_host_port(grpc_addr)?; - let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); - (PageserverProtocol::Grpc, host, port) - } else { - let (host, port) = parse_host_port(&conf.listen_pg_addr)?; - let port = port.unwrap_or(5432); - (PageserverProtocol::Libpq, host, port) - }; - vec![pageserver] + + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc } else { - let storage_controller = StorageController::from_env(env); - storage_controller - .tenant_locate(endpoint.tenant_id) - .await? - .shards - .into_iter() - .map(|shard| { - // Use gRPC if requested. - if endpoint.grpc { - ( - PageserverProtocol::Grpc, - Host::parse(&shard.listen_grpc_addr.expect("no gRPC address")) - .expect("bad hostname"), - shard.listen_grpc_port.expect("no gRPC port"), - ) - } else { - ( - PageserverProtocol::Libpq, - Host::parse(&shard.listen_pg_addr).expect("bad hostname"), - shard.listen_pg_port, - ) - } - }) - .collect::>() + PageserverProtocol::Libpq }; + let mut pageserver_conninfo = if let Some(ps_id) = args.endpoint_pageserver_id { + let conf = env.get_pageserver_conf(ps_id)?; + local_pageserver_conf_to_conn_info(conf)? + } else { + // Look up the currently attached location of the tenant, and its striping metadata, + // to pass these on to postgres. + let storage_controller = StorageController::from_env(env); + let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?; + + tenant_locate_response_to_conn_info(&locate_result)? + }; + pageserver_conninfo.prefer_protocol = prefer_protocol; + // If --safekeepers argument is given, use only the listed // safekeeper nodes; otherwise all from the env. let safekeepers = parse_safekeepers(&args.safekeepers)?; endpoint - .reconfigure(Some(pageservers), None, safekeepers, None) + .reconfigure(Some(&pageserver_conninfo), safekeepers, None) .await?; } EndpointCmd::RefreshConfiguration(args) => { diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 1c7f489d68..814ee2a52f 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -37,7 +37,7 @@ //! //! ``` //! -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::fmt::Display; use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; use std::path::PathBuf; @@ -58,8 +58,12 @@ use compute_api::responses::{ }; use compute_api::spec::{ Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol, - PgIdent, RemoteExtSpec, Role, + PageserverShardInfo, PgIdent, RemoteExtSpec, Role, }; + +// re-export these, because they're used in the reconfigure() function +pub use compute_api::spec::{PageserverConnectionInfo, PageserverShardConnectionInfo}; + use jsonwebtoken::jwk::{ AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations, OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, @@ -74,9 +78,11 @@ use sha2::{Digest, Sha256}; use spki::der::Decode; use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef}; use tracing::debug; -use url::Host; use utils::id::{NodeId, TenantId, TimelineId}; -use utils::shard::ShardStripeSize; +use utils::shard::{ShardCount, ShardIndex, ShardNumber}; + +use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT; +use postgres_connection::parse_host_port; use crate::local_env::LocalEnv; use crate::postgresql_conf::PostgresConf; @@ -387,9 +393,8 @@ pub struct EndpointStartArgs { pub endpoint_storage_addr: String, pub safekeepers_generation: Option, pub safekeepers: Vec, - pub pageservers: Vec<(PageserverProtocol, Host, u16)>, + pub pageserver_conninfo: PageserverConnectionInfo, pub remote_ext_base_url: Option, - pub shard_stripe_size: usize, pub create_test_user: bool, pub start_timeout: Duration, pub autoprewarm: bool, @@ -662,14 +667,6 @@ impl Endpoint { } } - fn build_pageserver_connstr(pageservers: &[(PageserverProtocol, Host, u16)]) -> String { - pageservers - .iter() - .map(|(scheme, host, port)| format!("{scheme}://no_user@{host}:{port}")) - .collect::>() - .join(",") - } - /// Map safekeepers ids to the actual connection strings. fn build_safekeepers_connstrs(&self, sk_ids: Vec) -> Result> { let mut safekeeper_connstrings = Vec::new(); @@ -715,9 +712,6 @@ impl Endpoint { std::fs::remove_dir_all(self.pgdata())?; } - let pageserver_connstring = Self::build_pageserver_connstr(&args.pageservers); - assert!(!pageserver_connstring.is_empty()); - let safekeeper_connstrings = self.build_safekeepers_connstrs(args.safekeepers)?; // check for file remote_extensions_spec.json @@ -732,6 +726,44 @@ impl Endpoint { remote_extensions = None; }; + // For the sake of backwards-compatibility, also fill in 'pageserver_connstring' + // + // XXX: I believe this is not really needed, except to make + // test_forward_compatibility happy. + // + // Use a closure so that we can conviniently return None in the middle of the + // loop. + let pageserver_connstring: Option = (|| { + let num_shards = args.pageserver_conninfo.shard_count.count(); + let mut connstrings = Vec::new(); + for shard_no in 0..num_shards { + let shard_index = ShardIndex { + shard_count: args.pageserver_conninfo.shard_count, + shard_number: ShardNumber(shard_no), + }; + let shard = args + .pageserver_conninfo + .shards + .get(&shard_index) + .ok_or_else(|| { + anyhow!( + "shard {} not found in pageserver_connection_info", + shard_index + ) + })?; + let pageserver = shard + .pageservers + .first() + .ok_or(anyhow!("must have at least one pageserver"))?; + if let Some(libpq_url) = &pageserver.libpq_url { + connstrings.push(libpq_url.clone()); + } else { + return Ok::<_, anyhow::Error>(None); + } + } + Ok(Some(connstrings.join(","))) + })()?; + // Create config file let config = { let mut spec = ComputeSpec { @@ -776,13 +808,14 @@ impl Endpoint { branch_id: None, endpoint_id: Some(self.endpoint_id.clone()), mode: self.mode, - pageserver_connstring: Some(pageserver_connstring), + pageserver_connection_info: Some(args.pageserver_conninfo.clone()), + pageserver_connstring, safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()), safekeeper_connstrings, storage_auth_token: args.auth_token.clone(), remote_extensions, pgbouncer_settings: None, - shard_stripe_size: Some(args.shard_stripe_size), + shard_stripe_size: args.pageserver_conninfo.stripe_size, // redundant with pageserver_connection_info.stripe_size local_proxy_config: None, reconfigure_concurrency: self.reconfigure_concurrency, drop_subscriptions_before_start: self.drop_subscriptions_before_start, @@ -966,7 +999,7 @@ impl Endpoint { // Update the pageservers in the spec file of the endpoint. This is useful to test the spec refresh scenario. pub async fn update_pageservers_in_config( &self, - pageservers: Vec<(PageserverProtocol, Host, u16)>, + pageserver_conninfo: &PageserverConnectionInfo, ) -> Result<()> { let config_path = self.endpoint_path().join("config.json"); let mut config: ComputeConfig = { @@ -974,10 +1007,8 @@ impl Endpoint { serde_json::from_reader(file)? }; - let pageserver_connstring = Self::build_pageserver_connstr(&pageservers); - assert!(!pageserver_connstring.is_empty()); let mut spec = config.spec.unwrap(); - spec.pageserver_connstring = Some(pageserver_connstring); + spec.pageserver_connection_info = Some(pageserver_conninfo.clone()); config.spec = Some(spec); let file = std::fs::File::create(&config_path)?; @@ -1020,8 +1051,7 @@ impl Endpoint { pub async fn reconfigure( &self, - pageservers: Option>, - stripe_size: Option, + pageserver_conninfo: Option<&PageserverConnectionInfo>, safekeepers: Option>, safekeeper_generation: Option, ) -> Result<()> { @@ -1036,15 +1066,15 @@ impl Endpoint { let postgresql_conf = self.read_postgresql_conf()?; spec.cluster.postgresql_conf = Some(postgresql_conf); - // If pageservers are not specified, don't change them. - if let Some(pageservers) = pageservers { - anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided"); - - let pageserver_connstr = Self::build_pageserver_connstr(&pageservers); - spec.pageserver_connstring = Some(pageserver_connstr); - if stripe_size.is_some() { - spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize); - } + if let Some(pageserver_conninfo) = pageserver_conninfo { + // If pageservers are provided, we need to ensure that they are not empty. + // This is a requirement for the compute_ctl configuration. + anyhow::ensure!( + !pageserver_conninfo.shards.is_empty(), + "no pageservers provided" + ); + spec.pageserver_connection_info = Some(pageserver_conninfo.clone()); + spec.shard_stripe_size = pageserver_conninfo.stripe_size; } // If safekeepers are not specified, don't change them. @@ -1093,11 +1123,9 @@ impl Endpoint { pub async fn reconfigure_pageservers( &self, - pageservers: Vec<(PageserverProtocol, Host, u16)>, - stripe_size: Option, + pageservers: &PageserverConnectionInfo, ) -> Result<()> { - self.reconfigure(Some(pageservers), stripe_size, None, None) - .await + self.reconfigure(Some(pageservers), None, None).await } pub async fn reconfigure_safekeepers( @@ -1105,7 +1133,7 @@ impl Endpoint { safekeepers: Vec, generation: SafekeeperGeneration, ) -> Result<()> { - self.reconfigure(None, None, Some(safekeepers), Some(generation)) + self.reconfigure(None, Some(safekeepers), Some(generation)) .await } @@ -1188,3 +1216,84 @@ impl Endpoint { ) } } + +/// If caller is telling us what pageserver to use, this is not a tenant which is +/// fully managed by storage controller, therefore not sharded. +pub fn local_pageserver_conf_to_conn_info( + conf: &crate::local_env::PageServerConf, +) -> Result { + let libpq_url = { + let (host, port) = parse_host_port(&conf.listen_pg_addr)?; + let port = port.unwrap_or(5432); + Some(format!("postgres://no_user@{host}:{port}")) + }; + let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr { + let (host, port) = parse_host_port(grpc_addr)?; + let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); + Some(format!("grpc://no_user@{host}:{port}")) + } else { + None + }; + let ps_conninfo = PageserverShardConnectionInfo { + id: Some(conf.id), + libpq_url, + grpc_url, + }; + + let shard_info = PageserverShardInfo { + pageservers: vec![ps_conninfo], + }; + + let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)] + .into_iter() + .collect(); + Ok(PageserverConnectionInfo { + shard_count: ShardCount::unsharded(), + stripe_size: None, + shards, + prefer_protocol: PageserverProtocol::default(), + }) +} + +pub fn tenant_locate_response_to_conn_info( + response: &pageserver_api::controller_api::TenantLocateResponse, +) -> Result { + let mut shards = HashMap::new(); + for shard in response.shards.iter() { + tracing::info!("parsing {}", shard.listen_pg_addr); + let libpq_url = { + let host = &shard.listen_pg_addr; + let port = shard.listen_pg_port; + Some(format!("postgres://no_user@{host}:{port}")) + }; + let grpc_url = if let Some(grpc_addr) = &shard.listen_grpc_addr { + let host = grpc_addr; + let port = shard.listen_grpc_port.expect("no gRPC port"); + Some(format!("grpc://no_user@{host}:{port}")) + } else { + None + }; + + let shard_info = PageserverShardInfo { + pageservers: vec![PageserverShardConnectionInfo { + id: Some(shard.node_id), + libpq_url, + grpc_url, + }], + }; + + shards.insert(shard.shard_id.to_index(), shard_info); + } + + let stripe_size = if response.shard_params.count.is_unsharded() { + None + } else { + Some(response.shard_params.stripe_size) + }; + Ok(PageserverConnectionInfo { + shard_count: response.shard_params.count, + stripe_size, + shards, + prefer_protocol: PageserverProtocol::default(), + }) +} diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 6709c06fc6..12d825e1bf 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -12,8 +12,9 @@ use regex::Regex; use remote_storage::RemotePath; use serde::{Deserialize, Serialize}; use url::Url; -use utils::id::{TenantId, TimelineId}; +use utils::id::{NodeId, TenantId, TimelineId}; use utils::lsn::Lsn; +use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize}; use crate::responses::TlsConfig; @@ -105,8 +106,27 @@ pub struct ComputeSpec { // updated to fill these fields, we can make these non optional. pub tenant_id: Option, pub timeline_id: Option, + + /// Pageserver information can be passed in three different ways: + /// 1. Here in `pageserver_connection_info` + /// 2. In the `pageserver_connstring` field. + /// 3. in `cluster.settings`. + /// + /// The goal is to use method 1. everywhere. But for backwards-compatibility with old + /// versions of the control plane, `compute_ctl` will check 2. and 3. if the + /// `pageserver_connection_info` field is missing. + /// + /// If both `pageserver_connection_info` and `pageserver_connstring`+`shard_stripe_size` are + /// given, they must contain the same information. + pub pageserver_connection_info: Option, + pub pageserver_connstring: Option, + /// Stripe size for pageserver sharding, in pages. This is set together with the legacy + /// `pageserver_connstring` field. When the modern `pageserver_connection_info` field is used, + /// the stripe size is stored in `pageserver_connection_info.stripe_size` instead. + pub shard_stripe_size: Option, + // More neon ids that we expose to the compute_ctl // and to postgres as neon extension GUCs. pub project_id: Option, @@ -139,10 +159,6 @@ pub struct ComputeSpec { pub pgbouncer_settings: Option>, - // Stripe size for pageserver sharding, in pages - #[serde(default)] - pub shard_stripe_size: Option, - /// Local Proxy configuration used for JWT authentication #[serde(default)] pub local_proxy_config: Option, @@ -217,6 +233,140 @@ pub enum ComputeFeature { UnknownFeature, } +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +pub struct PageserverConnectionInfo { + /// NB: 0 for unsharded tenants, 1 for sharded tenants with 1 shard, following storage + pub shard_count: ShardCount, + + /// INVARIANT: null if shard_count is 0, otherwise non-null and immutable + pub stripe_size: Option, + + pub shards: HashMap, + + /// If the compute supports both protocols, this indicates which one it should use. The compute + /// may use other available protocols too, if it doesn't support the preferred one. The URL's + /// for the protocol specified here must be present for all shards, i.e. do not mark a protocol + /// as preferred if it cannot actually be used with all the pageservers. + #[serde(default)] + pub prefer_protocol: PageserverProtocol, +} + +/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings. +/// +/// This is used for backwards-compatibility, to parse the legacy +/// [ComputeSpec::pageserver_connstring] field, or the 'neon.pageserver_connstring' GUC. Nowadays, +/// the 'pageserver_connection_info' field should be used instead. +impl PageserverConnectionInfo { + pub fn from_connstr( + connstr: &str, + stripe_size: Option, + ) -> Result { + let shard_infos: Vec<_> = connstr + .split(',') + .map(|connstr| PageserverShardInfo { + pageservers: vec![PageserverShardConnectionInfo { + id: None, + libpq_url: Some(connstr.to_string()), + grpc_url: None, + }], + }) + .collect(); + + match shard_infos.len() { + 0 => anyhow::bail!("empty connection string"), + 1 => { + // We assume that if there's only connection string, it means "unsharded", + // rather than a sharded system with just a single shard. The latter is + // possible in principle, but we never do it. + let shard_count = ShardCount::unsharded(); + let only_shard = shard_infos.first().unwrap().clone(); + let shards = vec![(ShardIndex::unsharded(), only_shard)]; + Ok(PageserverConnectionInfo { + shard_count, + stripe_size: None, + shards: shards.into_iter().collect(), + prefer_protocol: PageserverProtocol::Libpq, + }) + } + n => { + if stripe_size.is_none() { + anyhow::bail!("{n} shards but no stripe_size"); + } + let shard_count = ShardCount(n.try_into()?); + let shards = shard_infos + .into_iter() + .enumerate() + .map(|(idx, shard_info)| { + ( + ShardIndex { + shard_count, + shard_number: ShardNumber( + idx.try_into().expect("shard number fits in u8"), + ), + }, + shard_info, + ) + }) + .collect(); + Ok(PageserverConnectionInfo { + shard_count, + stripe_size, + shards, + prefer_protocol: PageserverProtocol::Libpq, + }) + } + } + } + + /// Convenience routine to get the connection string for a shard. + pub fn shard_url( + &self, + shard_number: ShardNumber, + protocol: PageserverProtocol, + ) -> anyhow::Result<&str> { + let shard_index = ShardIndex { + shard_number, + shard_count: self.shard_count, + }; + let shard = self.shards.get(&shard_index).ok_or_else(|| { + anyhow::anyhow!("shard connection info missing for shard {}", shard_index) + })?; + + // Just use the first pageserver in the list. That's good enough for this + // convenience routine; if you need more control, like round robin policy or + // failover support, roll your own. (As of this writing, we never have more than + // one pageserver per shard anyway, but that will change in the future.) + let pageserver = shard + .pageservers + .first() + .ok_or(anyhow::anyhow!("must have at least one pageserver"))?; + + let result = match protocol { + PageserverProtocol::Grpc => pageserver + .grpc_url + .as_ref() + .ok_or(anyhow::anyhow!("no grpc_url for shard {shard_index}"))?, + PageserverProtocol::Libpq => pageserver + .libpq_url + .as_ref() + .ok_or(anyhow::anyhow!("no libpq_url for shard {shard_index}"))?, + }; + Ok(result) + } +} + +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +pub struct PageserverShardInfo { + pub pageservers: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +pub struct PageserverShardConnectionInfo { + pub id: Option, + pub libpq_url: Option, + pub grpc_url: Option, +} + #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct RemoteExtSpec { pub public_extensions: Option>, @@ -334,6 +484,12 @@ impl ComputeMode { } } +impl Display for ComputeMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.to_type_str()) + } +} + /// Log level for audit logging #[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)] pub enum ComputeAudit { @@ -470,13 +626,15 @@ pub struct JwksSettings { pub jwt_audience: Option, } -/// Protocol used to connect to a Pageserver. Parsed from the connstring scheme. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +/// Protocol used to connect to a Pageserver. +#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)] pub enum PageserverProtocol { /// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme. #[default] + #[serde(rename = "libpq")] Libpq, /// A newer, gRPC-based protocol. Uses grpc:// scheme. + #[serde(rename = "grpc")] Grpc, } diff --git a/libs/utils/src/shard.rs b/libs/utils/src/shard.rs index 6ad6cab3a8..90323f7762 100644 --- a/libs/utils/src/shard.rs +++ b/libs/utils/src/shard.rs @@ -59,6 +59,10 @@ impl ShardCount { pub const MAX: Self = Self(u8::MAX); pub const MIN: Self = Self(0); + pub fn unsharded() -> Self { + ShardCount(0) + } + /// The internal value of a ShardCount may be zero, which means "1 shard, but use /// legacy format for TenantShardId that excludes the shard suffix", also known /// as [`TenantShardId::unsharded`]. diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index 158118f860..87bdbf376c 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -71,7 +71,7 @@ char *neon_project_id; char *neon_branch_id; char *neon_endpoint_id; int32 max_cluster_size; -char *page_server_connstring; +char *pageserver_connstring; char *neon_auth_token; int readahead_buffer_size = 128; @@ -1453,7 +1453,7 @@ PagestoreShmemInit(void) pg_atomic_init_u64(&pagestore_shared->begin_update_counter, 0); pg_atomic_init_u64(&pagestore_shared->end_update_counter, 0); memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap)); - AssignPageserverConnstring(page_server_connstring, NULL); + AssignPageserverConnstring(pageserver_connstring, NULL); } } @@ -1472,7 +1472,7 @@ pg_init_libpagestore(void) DefineCustomStringVariable("neon.pageserver_connstring", "connection string to the page server", NULL, - &page_server_connstring, + &pageserver_connstring, "", PGC_SIGHUP, 0, /* no flags required */ @@ -1643,7 +1643,7 @@ pg_init_libpagestore(void) if (neon_auth_token) neon_log(LOG, "using storage auth token from NEON_AUTH_TOKEN environment variable"); - if (page_server_connstring && page_server_connstring[0]) + if (pageserver_connstring[0]) { neon_log(PageStoreTrace, "set neon_smgr hook"); smgr_hook = smgr_neon; diff --git a/pgxn/neon/pagestore_client.h b/pgxn/neon/pagestore_client.h index 4470d3a94d..bfe00c9285 100644 --- a/pgxn/neon/pagestore_client.h +++ b/pgxn/neon/pagestore_client.h @@ -236,7 +236,7 @@ extern void prefetch_on_ps_disconnect(void); extern page_server_api *page_server; -extern char *page_server_connstring; +extern char *pageserver_connstring; extern int flush_every_n_requests; extern int readahead_buffer_size; extern char *neon_timeline; diff --git a/storage_controller/src/compute_hook.rs b/storage_controller/src/compute_hook.rs index fb03412f3c..efeb6005d5 100644 --- a/storage_controller/src/compute_hook.rs +++ b/storage_controller/src/compute_hook.rs @@ -6,13 +6,16 @@ use std::time::Duration; use anyhow::Context; use compute_api::spec::PageserverProtocol; -use control_plane::endpoint::{ComputeControlPlane, EndpointStatus}; +use compute_api::spec::PageserverShardInfo; +use control_plane::endpoint::{ + ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo, +}; use control_plane::local_env::LocalEnv; use futures::StreamExt; use hyper::StatusCode; use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT; use pageserver_api::controller_api::AvailabilityZone; -use pageserver_api::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId}; +use pageserver_api::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize, TenantShardId}; use postgres_connection::parse_host_port; use safekeeper_api::membership::SafekeeperGeneration; use serde::{Deserialize, Serialize}; @@ -506,27 +509,64 @@ impl ApiMethod for ComputeHookTenant { if endpoint.tenant_id == *tenant_id && endpoint.status() == EndpointStatus::Running { tracing::info!("Reconfiguring pageservers for endpoint {endpoint_name}"); - let pageservers = shards - .iter() - .map(|shard| { - let ps_conf = env - .get_pageserver_conf(shard.node_id) - .expect("Unknown pageserver"); - if endpoint.grpc { - let addr = ps_conf.listen_grpc_addr.as_ref().expect("no gRPC address"); - let (host, port) = parse_host_port(addr).expect("invalid gRPC address"); - let port = port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT); - (PageserverProtocol::Grpc, host, port) - } else { - let (host, port) = parse_host_port(&ps_conf.listen_pg_addr) - .expect("Unable to parse listen_pg_addr"); - (PageserverProtocol::Libpq, host, port.unwrap_or(5432)) - } - }) - .collect::>(); + let shard_count = match shards.len() { + 1 => ShardCount::unsharded(), + n => ShardCount(n.try_into().expect("too many shards")), + }; + + let mut shard_infos: HashMap = HashMap::new(); + + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc + } else { + PageserverProtocol::Libpq + }; + + for shard in shards.iter() { + let ps_conf = env + .get_pageserver_conf(shard.node_id) + .expect("Unknown pageserver"); + + let libpq_url = Some({ + let (host, port) = parse_host_port(&ps_conf.listen_pg_addr) + .expect("Unable to parse listen_pg_addr"); + let port = port.unwrap_or(5432); + format!("postgres://no_user@{host}:{port}") + }); + let grpc_url = if let Some(grpc_addr) = &ps_conf.listen_grpc_addr { + let (host, port) = + parse_host_port(grpc_addr).expect("invalid gRPC address"); + let port = port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT); + Some(format!("grpc://no_user@{host}:{port}")) + } else { + None + }; + let pageserver = PageserverShardConnectionInfo { + id: Some(shard.node_id), + libpq_url, + grpc_url, + }; + let shard_info = PageserverShardInfo { + pageservers: vec![pageserver], + }; + shard_infos.insert( + ShardIndex { + shard_number: shard.shard_number, + shard_count, + }, + shard_info, + ); + } + + let pageserver_conninfo = PageserverConnectionInfo { + shard_count, + stripe_size: stripe_size.map(|val| ShardStripeSize(val.0)), + shards: shard_infos, + prefer_protocol, + }; endpoint - .reconfigure_pageservers(pageservers, *stripe_size) + .reconfigure_pageservers(&pageserver_conninfo) .await .map_err(NotifyError::NeonLocal)?; } diff --git a/test_runner/regress/test_basebackup.py b/test_runner/regress/test_basebackup.py index d1b10ec85d..23b9105617 100644 --- a/test_runner/regress/test_basebackup.py +++ b/test_runner/regress/test_basebackup.py @@ -2,13 +2,15 @@ from __future__ import annotations from typing import TYPE_CHECKING +import pytest from fixtures.utils import wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import NeonEnvBuilder -def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("grpc", [True, False]) +def test_basebackup_cache(neon_env_builder: NeonEnvBuilder, grpc: bool): """ Simple test for basebackup cache. 1. Check that we always hit the cache after compute restart. @@ -22,7 +24,7 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): """ env = neon_env_builder.init_start() - ep = env.endpoints.create("main") + ep = env.endpoints.create("main", grpc=grpc) ps = env.pageserver ps_http = ps.http_client()