diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 975abd196a..ce19563679 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -6,8 +6,8 @@ use compute_api::responses::{ LfcPrewarmState, TlsConfig, }; use compute_api::spec::{ - ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, - PageserverConnectionInfo, PageserverShardConnectionInfo, PgIdent, + ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverConnectionInfo, + PageserverShardConnectionInfo, PgIdent, }; use futures::StreamExt; use futures::future::join_all; @@ -265,17 +265,23 @@ impl ParsedSpec { } } -fn extract_pageserver_conninfo_from_guc(pageserver_connstring_guc: &str) -> PageserverConnectionInfo { - +fn extract_pageserver_conninfo_from_guc( + pageserver_connstring_guc: &str, +) -> PageserverConnectionInfo { PageserverConnectionInfo { shards: pageserver_connstring_guc .split(',') .into_iter() .enumerate() - .map(|(i, connstr)| (i as u32, PageserverShardConnectionInfo { - libpq_url: Some(connstr.to_string()), - grpc_url: None, - })) + .map(|(i, connstr)| { + ( + i as u32, + PageserverShardConnectionInfo { + libpq_url: Some(connstr.to_string()), + grpc_url: None, + }, + ) + }) .collect(), prefer_grpc: false, } @@ -1041,13 +1047,18 @@ impl ComputeNode { fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<()> { let start_time = Instant::now(); - let shard0 = spec.pageserver_conninfo.shards.get(&0).expect("shard 0 connection info missing"); + let shard0 = spec + .pageserver_conninfo + .shards + .get(&0) + .expect("shard 0 connection info missing"); let shard0_url = shard0.grpc_url.clone().expect("no grpc_url for shard 0"); info!("getting basebackup@{} from pageserver {}", lsn, shard0_url); - + let chunks = tokio::runtime::Handle::current().block_on(async move { - let mut client = page_api::proto::PageServiceClient::connect(shard0_url.to_string()).await?; + let mut client = + page_api::proto::PageServiceClient::connect(shard0_url.to_string()).await?; let req = page_api::proto::GetBaseBackupRequest { lsn: lsn.0, @@ -1098,9 +1109,16 @@ impl ComputeNode { fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<()> { let start_time = Instant::now(); - let shard0 = spec.pageserver_conninfo.shards.get(&0).expect("shard 0 connection info missing"); + let shard0 = spec + .pageserver_conninfo + .shards + .get(&0) + .expect("shard 0 connection info missing"); let shard0_connstr = shard0.libpq_url.clone().expect("no libpq_url for shard 0"); - info!("getting basebackup@{} from pageserver {}", lsn, shard0_connstr); + info!( + "getting basebackup@{} from pageserver {}", + lsn, shard0_connstr + ); let mut config = postgres::Config::from_str(&shard0_connstr)?; @@ -1400,9 +1418,8 @@ impl ComputeNode { } }; - self.get_basebackup(compute_state, lsn).with_context(|| { - format!("failed to get basebackup@{}", lsn) - })?; + self.get_basebackup(compute_state, lsn) + .with_context(|| format!("failed to get basebackup@{}", lsn))?; // Update pg_hba.conf received with basebackup. 
update_pg_hba(pgdata_path)?; diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index c89febc38c..776ef7d6b6 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -62,8 +62,9 @@ pub fn write_postgres_conf( let mut grpc_urls: Option<Vec<String>> = Some(Vec::new()); for shardno in 0..conninfo.shards.len() { - let info = conninfo.shards.get(&(shardno as u32)) - .ok_or_else(|| anyhow::anyhow!("shard {shardno} missing from pageserver_connection_info shard map"))?; + let info = conninfo.shards.get(&(shardno as u32)).ok_or_else(|| { + anyhow::anyhow!("shard {shardno} missing from pageserver_connection_info shard map") + })?; if let Some(url) = &info.libpq_url { if let Some(ref mut urls) = libpq_urls { @@ -81,12 +82,20 @@ } } if let Some(libpq_urls) = libpq_urls { - writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(&libpq_urls.join(",")))?; + writeln!( + file, + "neon.pageserver_connstring={}", + escape_conf_value(&libpq_urls.join(",")) + )?; } else { writeln!(file, "# no neon.pageserver_connstring")?; } if let Some(grpc_urls) = grpc_urls { - writeln!(file, "neon.pageserver_grpc_urls={}", escape_conf_value(&grpc_urls.join(",")))?; + writeln!( + file, + "neon.pageserver_grpc_urls={}", + escape_conf_value(&grpc_urls.join(",")) + )?; } else { writeln!(file, "# no neon.pageserver_grpc_urls")?; } diff --git a/compute_tools/src/lsn_lease.rs b/compute_tools/src/lsn_lease.rs index e9fae18262..0e800145dc 100644 --- a/compute_tools/src/lsn_lease.rs +++ b/compute_tools/src/lsn_lease.rs @@ -81,7 +81,8 @@ fn acquire_lsn_lease_with_retry( let spec = state.pspec.as_ref().expect("spec must be set"); - spec.pageserver_conninfo.shards + spec.pageserver_conninfo + .shards .iter() .map(|(_shardno, conninfo)| { // FIXME: for now, this requires a libpq connection, the grpc API doesn't diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index aef27046b2..b56d1f84e5 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -16,7 +16,7 @@ use std::time::Duration; use anyhow::{Context, Result, anyhow, bail}; use clap::Parser; use compute_api::requests::ComputeClaimsScope; -use compute_api::spec::{ComputeMode, PageserverShardConnectionInfo, PageserverConnectionInfo}; +use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverShardConnectionInfo}; use control_plane::broker::StorageBroker; use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode}; use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage}; @@ -1531,8 +1531,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res // to pass these on to postgres. let storage_controller = StorageController::from_env(env); let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?; - let shards = futures::future::try_join_all( - locate_result.shards.into_iter().map(|shard| async move { + let shards = futures::future::try_join_all(locate_result.shards.into_iter().map( + |shard| async move { if let ComputeMode::Static(lsn) = endpoint.mode { // Initialize LSN leases for static computes.
let conf = env.get_pageserver_conf(shard.node_id).unwrap(); @@ -1546,7 +1546,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res let libpq_host = Host::parse(&shard.listen_pg_addr)?; let libpq_port = shard.listen_pg_port; - let libpq_url = Some(format!("postgres://no_user@{libpq_host}:{libpq_port}")); + let libpq_url = + Some(format!("postgres://no_user@{libpq_host}:{libpq_port}")); let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr { let grpc_port = shard.listen_grpc_port.expect("no gRPC port"); @@ -1559,8 +1560,8 @@ grpc_url, }; anyhow::Ok((shard.shard_id.shard_number.0 as u32, pageserver)) - }), - ) + }, + )) .await?; let stripe_size = locate_result.shard_params.stripe_size; @@ -1649,7 +1650,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res // Use gRPC if requested. let libpq_host = Host::parse(&shard.listen_pg_addr).expect("bad hostname"); let libpq_port = shard.listen_pg_port; - let libpq_url = Some(format!("postgres://no_user@{libpq_host}:{libpq_port}")); + let libpq_url = + Some(format!("postgres://no_user@{libpq_host}:{libpq_port}")); let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr { let grpc_port = shard.listen_grpc_port.expect("no gRPC port"); @@ -1657,10 +1659,13 @@ } else { None }; - (shard.shard_id.shard_number.0 as u32, PageserverShardConnectionInfo { - libpq_url, - grpc_url, - }) + ( + shard.shard_id.shard_number.0 as u32, + PageserverShardConnectionInfo { + libpq_url, + grpc_url, + }, + ) }) .collect::<HashMap<_, _>>() }; @@ -1671,7 +1676,9 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res // If --safekeepers argument is given, use only the listed // safekeeper nodes; otherwise all from the env.
let safekeepers = parse_safekeepers(&args.safekeepers)?; - endpoint.reconfigure(pageserver_conninfo, None, safekeepers).await?; + endpoint + .reconfigure(pageserver_conninfo, None, safekeepers) + .await?; } EndpointCmd::Stop(args) => { let endpoint_id = &args.endpoint_id; diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index ac5e3d14bf..694683f9bf 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -56,8 +56,8 @@ use compute_api::responses::{ TlsConfig, }; use compute_api::spec::{ - Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, - PgIdent, RemoteExtSpec, Role, + Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent, + RemoteExtSpec, Role, }; // re-export these, because they're used in the reconfigure() function @@ -993,7 +993,10 @@ impl Endpoint { stripe_size: Option<ShardStripeSize>, safekeepers: Option<Vec<NodeId>>, ) -> Result<()> { - anyhow::ensure!(!pageserver_conninfo.shards.is_empty(), "no pageservers provided"); + anyhow::ensure!( + !pageserver_conninfo.shards.is_empty(), + "no pageservers provided" + ); let (mut spec, compute_ctl_config) = { let config_path = self.endpoint_path().join("config.json"); diff --git a/pageserver/client_grpc/examples/load_test.rs b/pageserver/client_grpc/examples/load_test.rs index 68eef85c19..b189daa5ea 100644 --- a/pageserver/client_grpc/examples/load_test.rs +++ b/pageserver/client_grpc/examples/load_test.rs @@ -2,15 +2,14 @@ use std::collections::{HashMap, HashSet}; use std::sync::{ - Arc, - Mutex, + Arc, Mutex, atomic::{AtomicU64, AtomicUsize, Ordering}, }; use std::time::{Duration, Instant}; +use rand::Rng; use tokio::task; use tokio::time::sleep; -use rand::Rng; use tonic::Status; // Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate. @@ -184,13 +183,13 @@ async fn main() { // -------------------------------------- // 2. Pool parameters // -------------------------------------- - let connect_timeout = Duration::from_millis(500); - let connect_backoff = Duration::from_millis(100); - let max_consumers = 100; // test limit - let error_threshold = 2; // mock never fails - let max_idle_duration = Duration::from_secs(2); - let max_total_connections = 3; - let aggregate_metrics = None; + let connect_timeout = Duration::from_millis(500); + let connect_backoff = Duration::from_millis(100); + let max_consumers = 100; // test limit + let error_threshold = 2; // mock never fails + let max_idle_duration = Duration::from_secs(2); + let max_total_connections = 3; + let aggregate_metrics = None; let pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new( factory, @@ -211,10 +210,10 @@ let start_time = Instant::now(); for worker_id in 0..num_workers { - let pool_clone = Arc::clone(&pool); - let usage_clone = Arc::clone(&usage_map); - let seen_clone = Arc::clone(&seen_set); - let mc = max_consumers; + let pool_clone = Arc::clone(&pool); + let usage_clone = Arc::clone(&usage_map); + let seen_clone = Arc::clone(&seen_set); + let mc = max_consumers; let handle = task::spawn(async move { client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await; @@ -229,10 +228,7 @@ let _ = handle.await; } let elapsed = Instant::now().duration_since(start_time); - println!( - "All {} workers completed in {:?}", - num_workers, elapsed - ); + println!("All {} workers completed in {:?}", num_workers, elapsed); // -------------------------------------- // 5.
Print the total number of unique connections seen so far @@ -289,7 +285,10 @@ async fn main() { // 10. Because `client_worker` asserted inside that no connection // ever exceeded `max_consumers`, reaching this point means that check passed. // -------------------------------------- - println!("All per-connection usage stayed within max_consumers = {}.", max_consumers); + println!( + "All per-connection usage stayed within max_consumers = {}.", + max_consumers + ); println!("Load test complete; exiting cleanly."); } diff --git a/pageserver/client_grpc/examples/request_tracker_load_test.rs b/pageserver/client_grpc/examples/request_tracker_load_test.rs index 2963af0fa1..5741b289a5 100644 --- a/pageserver/client_grpc/examples/request_tracker_load_test.rs +++ b/pageserver/client_grpc/examples/request_tracker_load_test.rs @@ -1,15 +1,15 @@ // examples/request_tracker_load_test.rs -use std::{sync::Arc, time::Duration}; -use tokio; -use pageserver_client_grpc::request_tracker::RequestTracker; -use pageserver_client_grpc::request_tracker::MockStreamFactory; -use pageserver_client_grpc::request_tracker::StreamReturner; -use pageserver_client_grpc::client_cache::ConnectionPool; -use pageserver_client_grpc::client_cache::PooledItemFactory; +use pageserver_client_grpc::AuthInterceptor; use pageserver_client_grpc::ClientCacheOptions; use pageserver_client_grpc::PageserverClientAggregateMetrics; -use pageserver_client_grpc::AuthInterceptor; +use pageserver_client_grpc::client_cache::ConnectionPool; +use pageserver_client_grpc::client_cache::PooledItemFactory; +use pageserver_client_grpc::request_tracker::MockStreamFactory; +use pageserver_client_grpc::request_tracker::RequestTracker; +use pageserver_client_grpc::request_tracker::StreamReturner; +use std::{sync::Arc, time::Duration}; +use tokio; use pageserver_client_grpc::client_cache::ChannelFactory; @@ -22,8 +22,8 @@ use pageserver_api::key::Key; use utils::lsn::Lsn; use utils::shard::ShardIndex; -use futures::stream::FuturesOrdered; use futures::StreamExt; +use futures::stream::FuturesOrdered; use pageserver_page_api::proto; @@ -31,22 +31,21 @@ async fn main() { // 1) configure the client‐pool behavior let client_cache_options = ClientCacheOptions { - max_delay_ms: 0, - drop_rate: 0.0, - hang_rate: 0.0, - connect_timeout: Duration::from_secs(10), - connect_backoff: Duration::from_millis(200), - max_consumers: 64, - error_threshold: 10, - max_idle_duration: Duration::from_secs(60), + max_delay_ms: 0, + drop_rate: 0.0, + hang_rate: 0.0, + connect_timeout: Duration::from_secs(10), + connect_backoff: Duration::from_millis(200), + max_consumers: 64, + error_threshold: 10, + max_idle_duration: Duration::from_secs(60), max_total_connections: 12, }; // 2) metrics collector (we assume Default is implemented) let metrics = Arc::new(PageserverClientAggregateMetrics::new()); let pool = ConnectionPool::<StreamReturner>::new( - Arc::new(MockStreamFactory::new( - )), + Arc::new(MockStreamFactory::new()), client_cache_options.connect_timeout, client_cache_options.connect_backoff, client_cache_options.max_consumers, @@ -60,12 +59,13 @@ async fn main() { // There is no mock for the unary connection pool, so for now just // don't use this pool // - let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new( - "".to_string(), - client_cache_options.max_delay_ms, - client_cache_options.drop_rate, - client_cache_options.hang_rate, - )); + let channel_fact: Arc<dyn PooledItemFactory<Channel> + Send + Sync> = + Arc::new(ChannelFactory::new( + "".to_string(), +
client_cache_options.max_delay_ms, + client_cache_options.drop_rate, + client_cache_options.hang_rate, + )); let unary_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new( Arc::clone(&channel_fact), client_cache_options.connect_timeout, @@ -79,42 +79,34 @@ async fn main() { // ----------- // Dummy auth interceptor. This is not used in this test. - let auth_interceptor = AuthInterceptor::new("dummy_tenant_id", - "dummy_timeline_id", - None); - let tracker = RequestTracker::new( - pool, - unary_pool, - auth_interceptor, - ShardIndex::unsharded(), - ); + let auth_interceptor = AuthInterceptor::new("dummy_tenant_id", "dummy_timeline_id", None); + let tracker = RequestTracker::new(pool, unary_pool, auth_interceptor, ShardIndex::unsharded()); // 4) fire off 10 000 requests in parallel let mut handles = FuturesOrdered::new(); for _i in 0..500000 { + let mut rng = rand::thread_rng(); + let r = 0..=1000000i128; + let key: i128 = rng.gen_range(r.clone()); + let key = Key::from_i128(key); + let (rel_tag, block_no) = key + .to_rel_block() + .expect("we filter non-rel-block keys out above"); - let mut rng = rand::thread_rng(); - let r = 0..=1000000i128; - let key: i128 = rng.gen_range(r.clone()); - let key = Key::from_i128(key); - let (rel_tag, block_no) = key - .to_rel_block() - .expect("we filter non-rel-block keys out above"); - - let req2 = proto::GetPageRequest { - request_id: 0, - request_class: proto::GetPageClass::Normal as i32, - read_lsn: Some(proto::ReadLsn { - request_lsn: if rng.gen_bool(0.5) { - u64::from(Lsn::MAX) - } else { - 10000 - }, - not_modified_since_lsn: 10000, - }), - rel: Some(rel_tag.into()), - block_number: vec![block_no], - }; + let req2 = proto::GetPageRequest { + request_id: 0, + request_class: proto::GetPageClass::Normal as i32, + read_lsn: Some(proto::ReadLsn { + request_lsn: if rng.gen_bool(0.5) { + u64::from(Lsn::MAX) + } else { + 10000 + }, + not_modified_since_lsn: 10000, + }), + rel: Some(rel_tag.into()), + block_number: vec![block_no], + }; let req_model = pageserver_page_api::GetPageRequest::try_from(req2.clone()); // RequestTracker is Clone, so we can share it diff --git a/pageserver/client_grpc/src/client_cache.rs b/pageserver/client_grpc/src/client_cache.rs index 89c2d2b44e..b366ad0878 100644 --- a/pageserver/client_grpc/src/client_cache.rs +++ b/pageserver/client_grpc/src/client_cache.rs @@ -30,8 +30,8 @@ use http::Uri; use hyper_util::rt::TokioIo; use tower::service_fn; -use tokio_util::sync::CancellationToken; use async_trait::async_trait; +use tokio_util::sync::CancellationToken; // // The "TokioTcp" is flakey TCP network for testing purposes, in order @@ -168,7 +168,10 @@ impl AsyncWrite for TokioTcp { #[async_trait] pub trait PooledItemFactory<T>: Send + Sync + 'static { /// Create a new pooled item.
- async fn create(&self, connect_timeout: Duration) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>; + async fn create( + &self, + connect_timeout: Duration, + ) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>; } pub struct ChannelFactory { @@ -178,14 +181,8 @@ hang_rate: f64, } - impl ChannelFactory { - pub fn new( - endpoint: String, - max_delay_ms: u64, - drop_rate: f64, - hang_rate: f64, - ) -> Self { + pub fn new(endpoint: String, max_delay_ms: u64, drop_rate: f64, hang_rate: f64) -> Self { ChannelFactory { endpoint, max_delay_ms, @@ -197,7 +194,10 @@ impl ChannelFactory { #[async_trait] impl PooledItemFactory<Channel> for ChannelFactory { - async fn create(&self, connect_timeout: Duration) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> { + async fn create( + &self, + connect_timeout: Duration, + ) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> { let max_delay_ms = self.max_delay_ms; let drop_rate = self.drop_rate; let hang_rate = self.hang_rate; @@ -239,7 +239,6 @@ impl PooledItemFactory<Channel> for ChannelFactory { } }); - let attempt = tokio::time::timeout( connect_timeout, Endpoint::from_shared(self.endpoint.clone()) .timeout(connect_timeout) .connect_with_connector(connector), ) - .await; + .await; match attempt { Ok(Ok(channel)) => { // Connection succeeded Ok(Ok(channel)) } - Ok(Err(e)) => { - Ok(Err(tonic::Status::new( - tonic::Code::Unavailable, - format!("Failed to connect: {}", e), - ))) - } - Err(e) => { - Err(e) - } + Ok(Err(e)) => Ok(Err(tonic::Status::new( + tonic::Code::Unavailable, + format!("Failed to connect: {}", e), + ))), + Err(e) => Err(e), } } } - /// A pooled gRPC client with capacity tracking and error handling. pub struct ConnectionPool { inner: Mutex>, @@ -511,15 +505,15 @@ impl ConnectionPool { let mut inner = self_clone.inner.lock().await; inner.waiters += 1; if inner.waiters > (inner.in_progress * self_clone.max_consumers) { - if (inner.entries.len() + inner.in_progress) < self_clone.max_total_connections { - + if (inner.entries.len() + inner.in_progress) + < self_clone.max_total_connections + { let self_clone_spawn = Arc::clone(&self_clone); tokio::task::spawn(async move { self_clone_spawn.create_connection().await; }); inner.in_progress += 1; } - } } // Wait for a connection to become available, either because it @@ -548,7 +542,6 @@ impl ConnectionPool { } async fn create_connection(&self) -> () { - // Generate a random backoff to add some jitter so that connections // don't all retry at the same time. let mut backoff_delay = Duration::from_millis( @@ -595,9 +588,7 @@ impl ConnectionPool { None => {} } - let attempt = self.fact - .create(self.connect_timeout) - .await; + let attempt = self.fact.create(self.connect_timeout).await; match attempt { // Connection succeeded @@ -732,10 +723,8 @@ impl PooledClient { } pub async fn finish(mut self, result: Result<(), tonic::Status>) { self.is_ok = result.is_ok(); - self.pool.return_client( - self.id, - self.is_ok, - self.permit, - ).await; + self.pool + .return_client(self.id, self.is_ok, self.permit) + .await; } } diff --git a/pageserver/client_grpc/src/request_tracker.rs b/pageserver/client_grpc/src/request_tracker.rs index ed585660cc..899abf217b 100644 --- a/pageserver/client_grpc/src/request_tracker.rs +++ b/pageserver/client_grpc/src/request_tracker.rs @@ -7,30 +7,27 @@ //! account. In the future, there can be multiple pageservers per shard, and RequestTracker manages //! load balancing between them, but that's not implemented yet.
-use std::sync::Arc; -use pageserver_page_api::GetPageRequest; -use pageserver_page_api::GetPageResponse; -use pageserver_page_api::*; -use pageserver_page_api::proto; -use crate::client_cache; -use crate::client_cache::ConnectionPool; -use crate::client_cache::ChannelFactory; use crate::AuthInterceptor; -use tonic::{transport::{Channel}, Request}; use crate::ClientCacheOptions; use crate::PageserverClientAggregateMetrics; -use std::sync::atomic::AtomicU64; +use crate::client_cache; +use crate::client_cache::ChannelFactory; +use crate::client_cache::ConnectionPool; +use pageserver_page_api::GetPageRequest; +use pageserver_page_api::GetPageResponse; +use pageserver_page_api::proto; +use pageserver_page_api::*; +use std::sync::Arc; use std::sync::Mutex; +use std::sync::atomic::AtomicU64; +use tonic::{Request, transport::Channel}; use utils::shard::ShardIndex; -use tokio_stream::wrappers::ReceiverStream; use pageserver_page_api::proto::PageServiceClient; +use tokio_stream::wrappers::ReceiverStream; -use tonic::{ - Status, - Code, -}; +use tonic::{Code, Status}; use async_trait::async_trait; use std::time::Duration; @@ -45,20 +42,28 @@ use client_cache::PooledItemFactory; #[derive(Clone)] pub struct StreamReturner { sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>, - sender_hashmap: Arc<tokio::sync::Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, tonic::Status>>>>>, -} -pub struct MockStreamFactory { + sender_hashmap: Arc< + tokio::sync::Mutex< + std::collections::HashMap< + u64, + tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, tonic::Status>>, + >, + >, + >, } +pub struct MockStreamFactory {} impl MockStreamFactory { pub fn new() -> Self { - MockStreamFactory { - } + MockStreamFactory {} } } #[async_trait] impl PooledItemFactory<StreamReturner> for MockStreamFactory { - async fn create(&self, _connect_timeout: Duration) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> { + async fn create( + &self, + _connect_timeout: Duration, + ) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> { let (sender, mut receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000); // Create a StreamReturner that will send requests to the receiver channel let stream_returner = StreamReturner { @@ -69,7 +74,6 @@ impl PooledItemFactory<StreamReturner> for MockStreamFactory { let map = Arc::clone(&stream_returner.sender_hashmap); tokio::spawn(async move { while let Some(request) = receiver.recv().await { - // Break out of the loop with 1% chance if rand::random::<f64>() < 0.001 { break; @@ -111,7 +115,6 @@ } } - pub struct StreamFactory { connection_pool: Arc<ConnectionPool<Channel>>, auth_interceptor: AuthInterceptor, @@ -134,21 +137,22 @@ impl StreamFactory { #[async_trait] impl PooledItemFactory<StreamReturner> for StreamFactory { - async fn create(&self, _connect_timeout: Duration) -> - Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> - { - let pool_clone : Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool); + async fn create( + &self, + _connect_timeout: Duration, + ) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> { + let pool_clone: Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool); let pooled_client = pool_clone.get_client().await; let channel = pooled_client.unwrap().channel(); - let mut client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + let mut client = PageServiceClient::with_interceptor( + channel, + self.auth_interceptor.for_shard(self.shard), + ); let (sender, receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000); let outbound = ReceiverStream::new(receiver); - let client_resp = client - .get_pages(Request::new(outbound)) - .await; + let client_resp = client.get_pages(Request::new(outbound)).await; match client_resp { Err(status) => { @@ -161,17 +165,23 @@ impl PooledItemFactory<StreamReturner> for
StreamFactory { Ok(resp) => { let stream_returner = StreamReturner { sender: sender.clone(), - sender_hashmap: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), + sender_hashmap: Arc::new(tokio::sync::Mutex::new( + std::collections::HashMap::new(), + )), }; - let map : Arc<tokio::sync::Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, tonic::Status>>>>> - = Arc::clone(&stream_returner.sender_hashmap); + let map: Arc< + tokio::sync::Mutex< + std::collections::HashMap< + u64, + tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, tonic::Status>>, + >, + >, + > = Arc::clone(&stream_returner.sender_hashmap); tokio::spawn(async move { - let map_clone = Arc::clone(&map); let mut inner = resp.into_inner(); loop { - let resp = inner.message().await; if !resp.is_ok() { break; // Exit the loop if no more messages @@ -216,10 +226,11 @@ pub struct RequestTracker { } impl RequestTracker { - pub fn new(stream_pool: Arc<ConnectionPool<StreamReturner>>, - unary_pool: Arc<ConnectionPool<Channel>>, - auth_interceptor: AuthInterceptor, - shard: ShardIndex, + pub fn new( + stream_pool: Arc<ConnectionPool<StreamReturner>>, + unary_pool: Arc<ConnectionPool<Channel>>, + auth_interceptor: AuthInterceptor, + shard: ShardIndex, ) -> Self { let cur_id = Arc::new(AtomicU64::new(0)); @@ -228,7 +239,7 @@ impl RequestTracker { stream_pool: stream_pool, unary_pool: unary_pool, auth_interceptor: auth_interceptor, - shard: shard.clone() + shard: shard.clone(), } } @@ -240,9 +251,14 @@ impl RequestTracker { let unary_pool = Arc::clone(&self.unary_pool); let pooled_client = unary_pool.get_client().await.unwrap(); let channel = pooled_client.channel(); - let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + let mut ps_client = PageServiceClient::with_interceptor( + channel, + self.auth_interceptor.for_shard(self.shard), + ); let request = proto::CheckRelExistsRequest::from(req.clone()); - let response = ps_client.check_rel_exists(tonic::Request::new(request)).await; + let response = ps_client + .check_rel_exists(tonic::Request::new(request)) + .await; match response { Err(status) => { @@ -266,7 +282,10 @@ impl RequestTracker { let unary_pool = Arc::clone(&self.unary_pool); let pooled_client = unary_pool.get_client().await.unwrap(); let channel = pooled_client.channel(); - let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + let mut ps_client = PageServiceClient::with_interceptor( + channel, + self.auth_interceptor.for_shard(self.shard), + ); let request = proto::GetRelSizeRequest::from(req.clone()); let response = ps_client.get_rel_size(tonic::Request::new(request)).await; @@ -281,7 +300,6 @@ return Ok(resp.get_ref().num_blocks); } } - } } @@ -292,8 +310,12 @@ impl RequestTracker { loop { // Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool); - let pooled_client = unary_pool.get_client().await.unwrap();let channel = pooled_client.channel(); - let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + let pooled_client = unary_pool.get_client().await.unwrap(); + let channel = pooled_client.channel(); + let mut ps_client = PageServiceClient::with_interceptor( + channel, + self.auth_interceptor.for_shard(self.shard), + ); let request = proto::GetDbSizeRequest::from(req.clone()); let response = ps_client.get_db_size(tonic::Request::new(request)).await; @@ -308,7 +330,6 @@ return Ok(resp.get_ref().num_bytes); } } - } } @@ -322,7 +343,9 @@ impl RequestTracker { //let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1; let request_id = request.request_id; let response_sender: tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, tonic::Status>>; - let mut response_receiver: tokio::sync::mpsc::Receiver<Result<proto::GetPageResponse, tonic::Status>>; + let mut response_receiver: tokio::sync::mpsc::Receiver< + Result<proto::GetPageResponse, tonic::Status>, + >; (response_sender, response_receiver) = tokio::sync::mpsc::channel(1); //request.request_id = request_id; @@ -344,7 +367,9 @@ let mut map_inner = map.lock().await; map_inner.insert(request_id, response_sender); } - let sent = returner.sender.send(proto::GetPageRequest::from(request)) + let sent = returner + .sender + .send(proto::GetPageRequest::from(request)) .await; if let Err(_e) = sent { @@ -354,22 +379,27 @@ // remove from hashmap map_inner.remove(&request_id); } - stream_returner.finish(Err(Status::new(Code::Unknown, - "Failed to send request"))).await; + stream_returner + .finish(Err(Status::new(Code::Unknown, "Failed to send request"))) + .await; continue; } let response: Option<Result<proto::GetPageResponse, tonic::Status>>; response = response_receiver.recv().await; match response { - Some (resp) => { + Some(resp) => { match resp { Err(_status) => { // Handle the case where the response was not received - stream_returner.finish(Err(Status::new(Code::Unknown, - "Failed to receive response"))).await; + stream_returner + .finish(Err(Status::new( + Code::Unknown, + "Failed to receive response", + ))) + .await; continue; - }, + } Ok(resp) => { stream_returner.finish(Result::Ok(())).await; return Ok(resp.clone().into()); @@ -378,8 +408,9 @@ } None => { // Handle the case where the response channel was closed - stream_returner.finish(Err(Status::new(Code::Unknown, - "Response channel closed"))).await; + stream_returner + .finish(Err(Status::new(Code::Unknown, "Response channel closed"))) + .await; continue; } } @@ -407,25 +438,25 @@ impl ShardedRequestTracker { // Default configuration for the client.
These could be added to a config file // let tcp_client_cache_options = ClientCacheOptions { - max_delay_ms: 0, - drop_rate: 0.0, - hang_rate: 0.0, - connect_timeout: Duration::from_secs(1), - connect_backoff: Duration::from_millis(100), - max_consumers: 8, // Streams per connection - error_threshold: 10, - max_idle_duration: Duration::from_secs(5), + max_delay_ms: 0, + drop_rate: 0.0, + hang_rate: 0.0, + connect_timeout: Duration::from_secs(1), + connect_backoff: Duration::from_millis(100), + max_consumers: 8, // Streams per connection + error_threshold: 10, + max_idle_duration: Duration::from_secs(5), max_total_connections: 8, }; let stream_client_cache_options = ClientCacheOptions { - max_delay_ms: 0, - drop_rate: 0.0, - hang_rate: 0.0, - connect_timeout: Duration::from_secs(1), - connect_backoff: Duration::from_millis(100), - max_consumers: 64, // Requests per stream - error_threshold: 10, - max_idle_duration: Duration::from_secs(5), + max_delay_ms: 0, + drop_rate: 0.0, + hang_rate: 0.0, + connect_timeout: Duration::from_secs(1), + connect_backoff: Duration::from_millis(100), + max_consumers: 64, // Requests per stream + error_threshold: 10, + max_idle_duration: Duration::from_secs(5), max_total_connections: 64, // Total allowable number of streams }; ShardedRequestTracker { @@ -437,23 +468,26 @@ impl ShardedRequestTracker { } } - pub async fn update_shard_map(&self, - shard_urls: std::collections::HashMap<ShardIndex, String>, - metrics: Option<Arc<PageserverClientAggregateMetrics>>, - tenant_id: String, timeline_id: String, auth_str: Option<&str>) { - - - let mut trackers = std::collections::HashMap::new(); + pub async fn update_shard_map( + &self, + shard_urls: std::collections::HashMap<ShardIndex, String>, + metrics: Option<Arc<PageserverClientAggregateMetrics>>, + tenant_id: String, + timeline_id: String, + auth_str: Option<&str>, + ) { + let mut trackers = std::collections::HashMap::new(); for (shard, endpoint_url) in shard_urls { // // Create a pool of streams for streaming get_page requests // - let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new( - endpoint_url.clone(), - self.tcp_client_cache_options.max_delay_ms, - self.tcp_client_cache_options.drop_rate, - self.tcp_client_cache_options.hang_rate, - )); + let channel_fact: Arc<dyn PooledItemFactory<Channel> + Send + Sync> = + Arc::new(ChannelFactory::new( + endpoint_url.clone(), + self.tcp_client_cache_options.max_delay_ms, + self.tcp_client_cache_options.drop_rate, + self.tcp_client_cache_options.hang_rate, + )); let new_pool: Arc<ConnectionPool<Channel>>; new_pool = ConnectionPool::new( Arc::clone(&channel_fact), @@ -466,13 +500,15 @@ impl ShardedRequestTracker { metrics.clone(), ); - let auth_interceptor = AuthInterceptor::new(tenant_id.as_str(), - timeline_id.as_str(), - auth_str); + let auth_interceptor = + AuthInterceptor::new(tenant_id.as_str(), timeline_id.as_str(), auth_str); let stream_pool = ConnectionPool::<StreamReturner>::new( - Arc::new(StreamFactory::new(new_pool.clone(), - auth_interceptor.clone(), ShardIndex::unsharded())), + Arc::new(StreamFactory::new( + new_pool.clone(), + auth_interceptor.clone(), + ShardIndex::unsharded(), + )), self.stream_client_cache_options.connect_timeout, self.stream_client_cache_options.connect_backoff, self.stream_client_cache_options.max_consumers, @@ -495,7 +531,7 @@ self.tcp_client_cache_options.error_threshold, self.tcp_client_cache_options.max_idle_duration, self.tcp_client_cache_options.max_total_connections, - metrics.clone() + metrics.clone(), ); // // Create a new RequestTracker for this shard @@ -507,11 +543,7 @@ inner.trackers = trackers; } - pub async fn get_page( - &self,
req: GetPageRequest, - ) -> Result<GetPageResponse, tonic::Status> { - + pub async fn get_page(&self, req: GetPageRequest) -> Result<GetPageResponse, tonic::Status> { // Get shard index from the request and look up the RequestTracker instance for that shard let shard_index = ShardIndex::unsharded(); // TODO! let mut tracker = self.lookup_tracker_for_shard(shard_index)?; diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index db9f6a7592..719bbef5d9 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -1,4 +1,4 @@ -use std::collections::{HashSet, HashMap, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::future::Future; use std::num::NonZeroUsize; use std::pin::Pin; diff --git a/pgxn/neon/communicator/src/worker_process/main_loop.rs b/pgxn/neon/communicator/src/worker_process/main_loop.rs index 573e391262..ee20fe130f 100644 --- a/pgxn/neon/communicator/src/worker_process/main_loop.rs +++ b/pgxn/neon/communicator/src/worker_process/main_loop.rs @@ -93,11 +93,15 @@ pub(super) async fn init( .worker_process_init(last_lsn, file_cache); let request_tracker = ShardedRequestTracker::new(); - request_tracker.update_shard_map(shard_map, - None, - tenant_id, - timeline_id, - auth_token.as_deref()).await; + request_tracker + .update_shard_map( + shard_map, + None, + tenant_id, + timeline_id, + auth_token.as_deref(), + ) + .await; let request_counters = IntCounterVec::new( metrics::core::Opts::new( diff --git a/storage_controller/src/compute_hook.rs b/storage_controller/src/compute_hook.rs index 9095542c5f..f31f5d104e 100644 --- a/storage_controller/src/compute_hook.rs +++ b/storage_controller/src/compute_hook.rs @@ -5,7 +5,9 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Context; -use control_plane::endpoint::{ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo}; +use control_plane::endpoint::{ + ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo, +}; use control_plane::local_env::LocalEnv; use futures::StreamExt; use hyper::StatusCode; @@ -438,8 +440,8 @@ impl ComputeHook { format!("postgres://no_user@{host}:{port}") }); let grpc_url = if let Some(grpc_addr) = &ps_conf.listen_grpc_addr { - let (host, port) = parse_host_port(grpc_addr) - .expect("invalid gRPC address"); + let (host, port) = + parse_host_port(grpc_addr).expect("invalid gRPC address"); let port = port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT); Some(format!("grpc://no_user@{host}:{port}")) } else {
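
Reviewer note: to make the new connection-info plumbing easy to eyeball, here is a small self-contained sketch (not part of the patch) of the round-trip that `extract_pageserver_conninfo_from_guc` in compute.rs and the `write_postgres_conf` changes in config.rs perform between the comma-separated `neon.pageserver_connstring` GUC and the per-shard map. `ConnInfo`, `ShardConnInfo`, `from_guc`, and `to_guc` are simplified stand-ins for `compute_api::spec::PageserverConnectionInfo` / `PageserverShardConnectionInfo` and the real functions; `to_guc` models only the libpq branch (the grpc_urls branch is symmetric).

```rust
use std::collections::HashMap;

// Simplified stand-in for compute_api::spec::PageserverShardConnectionInfo.
#[derive(Debug, Default, Clone)]
struct ShardConnInfo {
    libpq_url: Option<String>,
    grpc_url: Option<String>,
}

// Simplified stand-in for compute_api::spec::PageserverConnectionInfo.
#[derive(Debug, Default)]
struct ConnInfo {
    shards: HashMap<u32, ShardConnInfo>,
    prefer_grpc: bool,
}

/// Mirrors extract_pageserver_conninfo_from_guc(): shard N is the N-th
/// entry of the comma-separated connstring; legacy GUCs are libpq-only.
fn from_guc(guc: &str) -> ConnInfo {
    ConnInfo {
        shards: guc
            .split(',')
            .enumerate()
            .map(|(i, connstr)| {
                (
                    i as u32,
                    ShardConnInfo {
                        libpq_url: Some(connstr.to_string()),
                        grpc_url: None,
                    },
                )
            })
            .collect(),
        prefer_grpc: false,
    }
}

/// Mirrors the write_postgres_conf() logic: emit the GUC only if every
/// shard in the contiguous range 0..len has a libpq URL, otherwise None
/// (the real code then writes a "# no neon.pageserver_connstring" comment).
fn to_guc(info: &ConnInfo) -> Option<String> {
    let mut urls = Vec::with_capacity(info.shards.len());
    for shardno in 0..info.shards.len() as u32 {
        urls.push(info.shards.get(&shardno)?.libpq_url.clone()?);
    }
    Some(urls.join(","))
}

fn main() {
    let guc = "postgres://no_user@ps1:6400,postgres://no_user@ps2:6400";
    let info = from_guc(guc);
    assert_eq!(info.shards.len(), 2);
    // Shard order is recovered from the map keys, not HashMap iteration
    // order, so the round-trip is exact.
    assert_eq!(to_guc(&info).as_deref(), Some(guc));
    println!("round-trip ok: {info:?}");
}
```

This is also why the shard map is keyed by shard number rather than kept as a plain Vec: the writer side can detect a hole in the 0..len range and fail loudly (the `shard {shardno} missing from pageserver_connection_info shard map` error above) instead of silently joining URLs in the wrong shard order.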