Run cargo fmt

Heikki Linnakangas
2025-06-29 21:21:07 +03:00
parent 8b7796cbfa
commit f3ba201800
12 changed files with 308 additions and 253 deletions
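The hunks below are mechanical reformatting: rustfmt rewraps long lines, reorders imports, and normalizes brace placement without changing behavior. A pass like this is typically produced by formatting the whole workspace; the exact invocation is an assumption, as it is not recorded in the commit:

    cargo fmt --all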

@@ -6,8 +6,8 @@ use compute_api::responses::{
LfcPrewarmState, TlsConfig,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion,
PageserverConnectionInfo, PageserverShardConnectionInfo, PgIdent,
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverConnectionInfo,
PageserverShardConnectionInfo, PgIdent,
};
use futures::StreamExt;
use futures::future::join_all;
@@ -265,17 +265,23 @@ impl ParsedSpec {
}
}
fn extract_pageserver_conninfo_from_guc(pageserver_connstring_guc: &str) -> PageserverConnectionInfo {
fn extract_pageserver_conninfo_from_guc(
pageserver_connstring_guc: &str,
) -> PageserverConnectionInfo {
PageserverConnectionInfo {
shards: pageserver_connstring_guc
.split(',')
.into_iter()
.enumerate()
.map(|(i, connstr)| (i as u32, PageserverShardConnectionInfo {
libpq_url: Some(connstr.to_string()),
grpc_url: None,
}))
.map(|(i, connstr)| {
(
i as u32,
PageserverShardConnectionInfo {
libpq_url: Some(connstr.to_string()),
grpc_url: None,
},
)
})
.collect(),
prefer_grpc: false,
}
@@ -1041,13 +1047,18 @@ impl ComputeNode {
fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<()> {
let start_time = Instant::now();
let shard0 = spec.pageserver_conninfo.shards.get(&0).expect("shard 0 connection info missing");
let shard0 = spec
.pageserver_conninfo
.shards
.get(&0)
.expect("shard 0 connection info missing");
let shard0_url = shard0.grpc_url.clone().expect("no grpc_url for shard 0");
info!("getting basebackup@{} from pageserver {}", lsn, shard0_url);
let chunks = tokio::runtime::Handle::current().block_on(async move {
let mut client = page_api::proto::PageServiceClient::connect(shard0_url.to_string()).await?;
let mut client =
page_api::proto::PageServiceClient::connect(shard0_url.to_string()).await?;
let req = page_api::proto::GetBaseBackupRequest {
lsn: lsn.0,
@@ -1098,9 +1109,16 @@ impl ComputeNode {
fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<()> {
let start_time = Instant::now();
let shard0 = spec.pageserver_conninfo.shards.get(&0).expect("shard 0 connection info missing");
let shard0 = spec
.pageserver_conninfo
.shards
.get(&0)
.expect("shard 0 connection info missing");
let shard0_connstr = shard0.libpq_url.clone().expect("no libpq_url for shard 0");
info!("getting basebackup@{} from pageserver {}", lsn, shard0_connstr);
info!(
"getting basebackup@{} from pageserver {}",
lsn, shard0_connstr
);
let mut config = postgres::Config::from_str(&shard0_connstr)?;
@@ -1400,9 +1418,8 @@ impl ComputeNode {
}
};
self.get_basebackup(compute_state, lsn).with_context(|| {
format!("failed to get basebackup@{}", lsn)
})?;
self.get_basebackup(compute_state, lsn)
.with_context(|| format!("failed to get basebackup@{}", lsn))?;
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path)?;

@@ -62,8 +62,9 @@ pub fn write_postgres_conf(
let mut grpc_urls: Option<Vec<String>> = Some(Vec::new());
for shardno in 0..conninfo.shards.len() {
let info = conninfo.shards.get(&(shardno as u32))
.ok_or_else(|| anyhow::anyhow!("shard {shardno} missing from pageserver_connection_info shard map"))?;
let info = conninfo.shards.get(&(shardno as u32)).ok_or_else(|| {
anyhow::anyhow!("shard {shardno} missing from pageserver_connection_info shard map")
})?;
if let Some(url) = &info.libpq_url {
if let Some(ref mut urls) = libpq_urls {
@@ -81,12 +82,20 @@ pub fn write_postgres_conf(
}
}
if let Some(libpq_urls) = libpq_urls {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(&libpq_urls.join(",")))?;
writeln!(
file,
"neon.pageserver_connstring={}",
escape_conf_value(&libpq_urls.join(","))
)?;
} else {
writeln!(file, "# no neon.pageserver_connstring")?;
}
if let Some(grpc_urls) = grpc_urls {
writeln!(file, "neon.pageserver_grpc_urls={}", escape_conf_value(&grpc_urls.join(",")))?;
writeln!(
file,
"neon.pageserver_grpc_urls={}",
escape_conf_value(&grpc_urls.join(","))
)?;
} else {
writeln!(file, "# no neon.pageserver_grpc_urls")?;
}

@@ -81,7 +81,8 @@ fn acquire_lsn_lease_with_retry(
let spec = state.pspec.as_ref().expect("spec must be set");
spec.pageserver_conninfo.shards
spec.pageserver_conninfo
.shards
.iter()
.map(|(_shardno, conninfo)| {
// FIXME: for now, this requires a libpq connection, the grpc API doesn't

@@ -16,7 +16,7 @@ use std::time::Duration;
use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::{ComputeMode, PageserverShardConnectionInfo, PageserverConnectionInfo};
use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverShardConnectionInfo};
use control_plane::broker::StorageBroker;
use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode};
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
@@ -1531,8 +1531,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
// to pass these on to postgres.
let storage_controller = StorageController::from_env(env);
let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
let shards = futures::future::try_join_all(
locate_result.shards.into_iter().map(|shard| async move {
let shards = futures::future::try_join_all(locate_result.shards.into_iter().map(
|shard| async move {
if let ComputeMode::Static(lsn) = endpoint.mode {
// Initialize LSN leases for static computes.
let conf = env.get_pageserver_conf(shard.node_id).unwrap();
@@ -1546,7 +1546,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
let libpq_host = Host::parse(&shard.listen_pg_addr)?;
let libpq_port = shard.listen_pg_port;
let libpq_url = Some(format!("postgres://no_user@{libpq_host}:{libpq_port}"));
let libpq_url =
Some(format!("postgres://no_user@{libpq_host}:{libpq_port}"));
let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr {
let grpc_port = shard.listen_grpc_port.expect("no gRPC port");
@@ -1559,8 +1560,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
grpc_url,
};
anyhow::Ok((shard.shard_id.shard_number.0 as u32, pageserver))
}),
)
},
))
.await?;
let stripe_size = locate_result.shard_params.stripe_size;
@@ -1649,7 +1650,8 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
// Use gRPC if requested.
let libpq_host = Host::parse(&shard.listen_pg_addr).expect("bad hostname");
let libpq_port = shard.listen_pg_port;
let libpq_url = Some(format!("postgres://no_user@{libpq_host}:{libpq_port}"));
let libpq_url =
Some(format!("postgres://no_user@{libpq_host}:{libpq_port}"));
let grpc_url = if let Some(grpc_host) = shard.listen_grpc_addr {
let grpc_port = shard.listen_grpc_port.expect("no gRPC port");
@@ -1657,10 +1659,13 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
} else {
None
};
(shard.shard_id.shard_number.0 as u32, PageserverShardConnectionInfo {
libpq_url,
grpc_url,
})
(
shard.shard_id.shard_number.0 as u32,
PageserverShardConnectionInfo {
libpq_url,
grpc_url,
},
)
})
.collect::<Vec<_>>()
};
@@ -1671,7 +1676,9 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
// If --safekeepers argument is given, use only the listed
// safekeeper nodes; otherwise all from the env.
let safekeepers = parse_safekeepers(&args.safekeepers)?;
endpoint.reconfigure(pageserver_conninfo, None, safekeepers).await?;
endpoint
.reconfigure(pageserver_conninfo, None, safekeepers)
.await?;
}
EndpointCmd::Stop(args) => {
let endpoint_id = &args.endpoint_id;

@@ -56,8 +56,8 @@ use compute_api::responses::{
TlsConfig,
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database,
PgIdent, RemoteExtSpec, Role,
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent,
RemoteExtSpec, Role,
};
// re-export these, because they're used in the reconfigure() function
@@ -993,7 +993,10 @@ impl Endpoint {
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
) -> Result<()> {
anyhow::ensure!(!pageserver_conninfo.shards.is_empty(), "no pageservers provided");
anyhow::ensure!(
!pageserver_conninfo.shards.is_empty(),
"no pageservers provided"
);
let (mut spec, compute_ctl_config) = {
let config_path = self.endpoint_path().join("config.json");

@@ -2,15 +2,14 @@
use std::collections::{HashMap, HashSet};
use std::sync::{
Arc,
Mutex,
Arc, Mutex,
atomic::{AtomicU64, AtomicUsize, Ordering},
};
use std::time::{Duration, Instant};
use rand::Rng;
use tokio::task;
use tokio::time::sleep;
use rand::Rng;
use tonic::Status;
// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate.
@@ -184,13 +183,13 @@ async fn main() {
// --------------------------------------
// 2. Pool parameters
// --------------------------------------
let connect_timeout = Duration::from_millis(500);
let connect_backoff = Duration::from_millis(100);
let max_consumers = 100; // test limit
let error_threshold = 2; // mock never fails
let max_idle_duration = Duration::from_secs(2);
let max_total_connections = 3;
let aggregate_metrics = None;
let connect_timeout = Duration::from_millis(500);
let connect_backoff = Duration::from_millis(100);
let max_consumers = 100; // test limit
let error_threshold = 2; // mock never fails
let max_idle_duration = Duration::from_secs(2);
let max_total_connections = 3;
let aggregate_metrics = None;
let pool: Arc<ConnectionPool<MockConnection>> = ConnectionPool::new(
factory,
@@ -211,10 +210,10 @@ async fn main() {
let start_time = Instant::now();
for worker_id in 0..num_workers {
let pool_clone = Arc::clone(&pool);
let usage_clone = Arc::clone(&usage_map);
let seen_clone = Arc::clone(&seen_set);
let mc = max_consumers;
let pool_clone = Arc::clone(&pool);
let usage_clone = Arc::clone(&usage_map);
let seen_clone = Arc::clone(&seen_set);
let mc = max_consumers;
let handle = task::spawn(async move {
client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await;
@@ -229,10 +228,7 @@ async fn main() {
let _ = handle.await;
}
let elapsed = Instant::now().duration_since(start_time);
println!(
"All {} workers completed in {:?}",
num_workers, elapsed
);
println!("All {} workers completed in {:?}", num_workers, elapsed);
// --------------------------------------
// 5. Print the total number of unique connections seen so far
@@ -289,7 +285,10 @@ async fn main() {
// 10. Because `client_worker` asserted inside that no connection
// ever exceeded `max_consumers`, reaching this point means that check passed.
// --------------------------------------
println!("All per-connection usage stayed within max_consumers = {}.", max_consumers);
println!(
"All per-connection usage stayed within max_consumers = {}.",
max_consumers
);
println!("Load test complete; exiting cleanly.");
}

@@ -1,15 +1,15 @@
// examples/request_tracker_load_test.rs
use std::{sync::Arc, time::Duration};
use tokio;
use pageserver_client_grpc::request_tracker::RequestTracker;
use pageserver_client_grpc::request_tracker::MockStreamFactory;
use pageserver_client_grpc::request_tracker::StreamReturner;
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
use pageserver_client_grpc::AuthInterceptor;
use pageserver_client_grpc::ClientCacheOptions;
use pageserver_client_grpc::PageserverClientAggregateMetrics;
use pageserver_client_grpc::AuthInterceptor;
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
use pageserver_client_grpc::request_tracker::MockStreamFactory;
use pageserver_client_grpc::request_tracker::RequestTracker;
use pageserver_client_grpc::request_tracker::StreamReturner;
use std::{sync::Arc, time::Duration};
use tokio;
use pageserver_client_grpc::client_cache::ChannelFactory;
@@ -22,8 +22,8 @@ use pageserver_api::key::Key;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use futures::stream::FuturesOrdered;
use futures::StreamExt;
use futures::stream::FuturesOrdered;
use pageserver_page_api::proto;
@@ -31,22 +31,21 @@ use pageserver_page_api::proto;
async fn main() {
// 1) configure the clientpool behavior
let client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(10),
connect_backoff: Duration::from_millis(200),
max_consumers: 64,
error_threshold: 10,
max_idle_duration: Duration::from_secs(60),
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(10),
connect_backoff: Duration::from_millis(200),
max_consumers: 64,
error_threshold: 10,
max_idle_duration: Duration::from_secs(60),
max_total_connections: 12,
};
// 2) metrics collector (we assume Default is implemented)
let metrics = Arc::new(PageserverClientAggregateMetrics::new());
let pool = ConnectionPool::<StreamReturner>::new(
Arc::new(MockStreamFactory::new(
)),
Arc::new(MockStreamFactory::new()),
client_cache_options.connect_timeout,
client_cache_options.connect_backoff,
client_cache_options.max_consumers,
@@ -60,12 +59,13 @@ async fn main() {
// There is no mock for the unary connection pool, so for now just
// don't use this pool
//
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
"".to_string(),
client_cache_options.max_delay_ms,
client_cache_options.drop_rate,
client_cache_options.hang_rate,
));
let channel_fact: Arc<dyn PooledItemFactory<Channel> + Send + Sync> =
Arc::new(ChannelFactory::new(
"".to_string(),
client_cache_options.max_delay_ms,
client_cache_options.drop_rate,
client_cache_options.hang_rate,
));
let unary_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
Arc::clone(&channel_fact),
client_cache_options.connect_timeout,
@@ -79,42 +79,34 @@ async fn main() {
// -----------
// Dummy auth interceptor. This is not used in this test.
let auth_interceptor = AuthInterceptor::new("dummy_tenant_id",
"dummy_timeline_id",
None);
let tracker = RequestTracker::new(
pool,
unary_pool,
auth_interceptor,
ShardIndex::unsharded(),
);
let auth_interceptor = AuthInterceptor::new("dummy_tenant_id", "dummy_timeline_id", None);
let tracker = RequestTracker::new(pool, unary_pool, auth_interceptor, ShardIndex::unsharded());
// 4) fire off 10 000 requests in parallel
let mut handles = FuturesOrdered::new();
for _i in 0..500000 {
let mut rng = rand::thread_rng();
let r = 0..=1000000i128;
let key: i128 = rng.gen_range(r.clone());
let key = Key::from_i128(key);
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
let mut rng = rand::thread_rng();
let r = 0..=1000000i128;
let key: i128 = rng.gen_range(r.clone());
let key = Key::from_i128(key);
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
let req2 = proto::GetPageRequest {
request_id: 0,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: if rng.gen_bool(0.5) {
u64::from(Lsn::MAX)
} else {
10000
},
not_modified_since_lsn: 10000,
}),
rel: Some(rel_tag.into()),
block_number: vec![block_no],
};
let req2 = proto::GetPageRequest {
request_id: 0,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: if rng.gen_bool(0.5) {
u64::from(Lsn::MAX)
} else {
10000
},
not_modified_since_lsn: 10000,
}),
rel: Some(rel_tag.into()),
block_number: vec![block_no],
};
let req_model = pageserver_page_api::GetPageRequest::try_from(req2.clone());
// RequestTracker is Clone, so we can share it

@@ -30,8 +30,8 @@ use http::Uri;
use hyper_util::rt::TokioIo;
use tower::service_fn;
use tokio_util::sync::CancellationToken;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
//
// The "TokioTcp" is flakey TCP network for testing purposes, in order
@@ -168,7 +168,10 @@ impl AsyncWrite for TokioTcp {
#[async_trait]
pub trait PooledItemFactory<T>: Send + Sync + 'static {
/// Create a new pooled item.
async fn create(&self, connect_timeout: Duration) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>;
async fn create(
&self,
connect_timeout: Duration,
) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>;
}
pub struct ChannelFactory {
@@ -178,14 +181,8 @@ pub struct ChannelFactory {
hang_rate: f64,
}
impl ChannelFactory {
pub fn new(
endpoint: String,
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
) -> Self {
pub fn new(endpoint: String, max_delay_ms: u64, drop_rate: f64, hang_rate: f64) -> Self {
ChannelFactory {
endpoint,
max_delay_ms,
@@ -197,7 +194,10 @@ impl ChannelFactory {
#[async_trait]
impl PooledItemFactory<Channel> for ChannelFactory {
async fn create(&self, connect_timeout: Duration) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> {
async fn create(
&self,
connect_timeout: Duration,
) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> {
let max_delay_ms = self.max_delay_ms;
let drop_rate = self.drop_rate;
let hang_rate = self.hang_rate;
@@ -239,7 +239,6 @@ impl PooledItemFactory<Channel> for ChannelFactory {
}
});
let attempt = tokio::time::timeout(
connect_timeout,
Endpoint::from_shared(self.endpoint.clone())
@@ -247,26 +246,21 @@ impl PooledItemFactory<Channel> for ChannelFactory {
.timeout(connect_timeout)
.connect_with_connector(connector),
)
.await;
.await;
match attempt {
Ok(Ok(channel)) => {
// Connection succeeded
Ok(Ok(channel))
}
Ok(Err(e)) => {
Ok(Err(tonic::Status::new(
tonic::Code::Unavailable,
format!("Failed to connect: {}", e),
)))
}
Err(e) => {
Err(e)
}
Ok(Err(e)) => Ok(Err(tonic::Status::new(
tonic::Code::Unavailable,
format!("Failed to connect: {}", e),
))),
Err(e) => Err(e),
}
}
}
/// A pooled gRPC client with capacity tracking and error handling.
pub struct ConnectionPool<T> {
inner: Mutex<Inner<T>>,
@@ -511,15 +505,15 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
let mut inner = self_clone.inner.lock().await;
inner.waiters += 1;
if inner.waiters > (inner.in_progress * self_clone.max_consumers) {
if (inner.entries.len() + inner.in_progress) < self_clone.max_total_connections {
if (inner.entries.len() + inner.in_progress)
< self_clone.max_total_connections
{
let self_clone_spawn = Arc::clone(&self_clone);
tokio::task::spawn(async move {
self_clone_spawn.create_connection().await;
});
inner.in_progress += 1;
}
}
}
// Wait for a connection to become available, either because it
@@ -548,7 +542,6 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
}
async fn create_connection(&self) -> () {
// Generate a random backoff to add some jitter so that connections
// don't all retry at the same time.
let mut backoff_delay = Duration::from_millis(
@@ -595,9 +588,7 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
None => {}
}
let attempt = self.fact
.create(self.connect_timeout)
.await;
let attempt = self.fact.create(self.connect_timeout).await;
match attempt {
// Connection succeeded
@@ -732,10 +723,8 @@ impl<T: Clone + Send + 'static> PooledClient<T> {
}
pub async fn finish(mut self, result: Result<(), tonic::Status>) {
self.is_ok = result.is_ok();
self.pool.return_client(
self.id,
self.is_ok,
self.permit,
).await;
self.pool
.return_client(self.id, self.is_ok, self.permit)
.await;
}
}

@@ -7,30 +7,27 @@
//! account. In the future, there can be multiple pageservers per shard, and RequestTracker manages
//! load balancing between them, but that's not implemented yet.
use std::sync::Arc;
use pageserver_page_api::GetPageRequest;
use pageserver_page_api::GetPageResponse;
use pageserver_page_api::*;
use pageserver_page_api::proto;
use crate::client_cache;
use crate::client_cache::ConnectionPool;
use crate::client_cache::ChannelFactory;
use crate::AuthInterceptor;
use tonic::{transport::{Channel}, Request};
use crate::ClientCacheOptions;
use crate::PageserverClientAggregateMetrics;
use std::sync::atomic::AtomicU64;
use crate::client_cache;
use crate::client_cache::ChannelFactory;
use crate::client_cache::ConnectionPool;
use pageserver_page_api::GetPageRequest;
use pageserver_page_api::GetPageResponse;
use pageserver_page_api::proto;
use pageserver_page_api::*;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicU64;
use tonic::{Request, transport::Channel};
use utils::shard::ShardIndex;
use tokio_stream::wrappers::ReceiverStream;
use pageserver_page_api::proto::PageServiceClient;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{
Status,
Code,
};
use tonic::{Code, Status};
use async_trait::async_trait;
use std::time::Duration;
@@ -45,20 +42,28 @@ use client_cache::PooledItemFactory;
#[derive(Clone)]
pub struct StreamReturner {
sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
sender_hashmap: Arc<tokio::sync::Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>>>>,
}
pub struct MockStreamFactory {
sender_hashmap: Arc<
tokio::sync::Mutex<
std::collections::HashMap<
u64,
tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>,
>,
>,
>,
}
pub struct MockStreamFactory {}
impl MockStreamFactory {
pub fn new() -> Self {
MockStreamFactory {
}
MockStreamFactory {}
}
}
#[async_trait]
impl PooledItemFactory<StreamReturner> for MockStreamFactory {
async fn create(&self, _connect_timeout: Duration) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> {
async fn create(
&self,
_connect_timeout: Duration,
) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> {
let (sender, mut receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
// Create a StreamReturner that will send requests to the receiver channel
let stream_returner = StreamReturner {
@@ -69,7 +74,6 @@ impl PooledItemFactory<StreamReturner> for MockStreamFactory {
let map = Arc::clone(&stream_returner.sender_hashmap);
tokio::spawn(async move {
while let Some(request) = receiver.recv().await {
// Break out of the loop with 1% chance
if rand::random::<f32>() < 0.001 {
break;
@@ -111,7 +115,6 @@ impl PooledItemFactory<StreamReturner> for MockStreamFactory {
}
}
pub struct StreamFactory {
connection_pool: Arc<client_cache::ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
@@ -134,21 +137,22 @@ impl StreamFactory {
#[async_trait]
impl PooledItemFactory<StreamReturner> for StreamFactory {
async fn create(&self, _connect_timeout: Duration) ->
Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed>
{
let pool_clone : Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool);
async fn create(
&self,
_connect_timeout: Duration,
) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> {
let pool_clone: Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool);
let pooled_client = pool_clone.get_client().await;
let channel = pooled_client.unwrap().channel();
let mut client =
PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let mut client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let (sender, receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
let outbound = ReceiverStream::new(receiver);
let client_resp = client
.get_pages(Request::new(outbound))
.await;
let client_resp = client.get_pages(Request::new(outbound)).await;
match client_resp {
Err(status) => {
@@ -161,17 +165,23 @@ impl PooledItemFactory<StreamReturner> for StreamFactory {
Ok(resp) => {
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
sender_hashmap: Arc::new(tokio::sync::Mutex::new(
std::collections::HashMap::new(),
)),
};
let map : Arc<tokio::sync::Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>>>>
= Arc::clone(&stream_returner.sender_hashmap);
let map: Arc<
tokio::sync::Mutex<
std::collections::HashMap<
u64,
tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>,
>,
>,
> = Arc::clone(&stream_returner.sender_hashmap);
tokio::spawn(async move {
let map_clone = Arc::clone(&map);
let mut inner = resp.into_inner();
loop {
let resp = inner.message().await;
if !resp.is_ok() {
break; // Exit the loop if no more messages
@@ -216,10 +226,11 @@ pub struct RequestTracker {
}
impl RequestTracker {
pub fn new(stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
pub fn new(
stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
) -> Self {
let cur_id = Arc::new(AtomicU64::new(0));
@@ -228,7 +239,7 @@ impl RequestTracker {
stream_pool: stream_pool,
unary_pool: unary_pool,
auth_interceptor: auth_interceptor,
shard: shard.clone()
shard: shard.clone(),
}
}
@@ -240,9 +251,14 @@ impl RequestTracker {
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let mut ps_client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::CheckRelExistsRequest::from(req.clone());
let response = ps_client.check_rel_exists(tonic::Request::new(request)).await;
let response = ps_client
.check_rel_exists(tonic::Request::new(request))
.await;
match response {
Err(status) => {
@@ -266,7 +282,10 @@ impl RequestTracker {
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let mut ps_client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::GetRelSizeRequest::from(req.clone());
let response = ps_client.get_rel_size(tonic::Request::new(request)).await;
@@ -281,7 +300,6 @@ impl RequestTracker {
return Ok(resp.get_ref().num_blocks);
}
}
}
}
@@ -292,8 +310,12 @@ impl RequestTracker {
loop {
// Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::GetDbSizeRequest::from(req.clone());
let response = ps_client.get_db_size(tonic::Request::new(request)).await;
@@ -308,7 +330,6 @@ impl RequestTracker {
return Ok(resp.get_ref().num_bytes);
}
}
}
}
@@ -322,7 +343,9 @@ impl RequestTracker {
//let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1;
let request_id = request.request_id;
let response_sender: tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>;
let mut response_receiver: tokio::sync::mpsc::Receiver<Result<proto::GetPageResponse, Status>>;
let mut response_receiver: tokio::sync::mpsc::Receiver<
Result<proto::GetPageResponse, Status>,
>;
(response_sender, response_receiver) = tokio::sync::mpsc::channel(1);
//request.request_id = request_id;
@@ -344,7 +367,9 @@ impl RequestTracker {
let mut map_inner = map.lock().await;
map_inner.insert(request_id, response_sender);
}
let sent = returner.sender.send(proto::GetPageRequest::from(request))
let sent = returner
.sender
.send(proto::GetPageRequest::from(request))
.await;
if let Err(_e) = sent {
@@ -354,22 +379,27 @@ impl RequestTracker {
// remove from hashmap
map_inner.remove(&request_id);
}
stream_returner.finish(Err(Status::new(Code::Unknown,
"Failed to send request"))).await;
stream_returner
.finish(Err(Status::new(Code::Unknown, "Failed to send request")))
.await;
continue;
}
let response: Option<Result<proto::GetPageResponse, Status>>;
response = response_receiver.recv().await;
match response {
Some (resp) => {
Some(resp) => {
match resp {
Err(_status) => {
// Handle the case where the response was not received
stream_returner.finish(Err(Status::new(Code::Unknown,
"Failed to receive response"))).await;
stream_returner
.finish(Err(Status::new(
Code::Unknown,
"Failed to receive response",
)))
.await;
continue;
},
}
Ok(resp) => {
stream_returner.finish(Result::Ok(())).await;
return Ok(resp.clone().into());
@@ -378,8 +408,9 @@ impl RequestTracker {
}
None => {
// Handle the case where the response channel was closed
stream_returner.finish(Err(Status::new(Code::Unknown,
"Response channel closed"))).await;
stream_returner
.finish(Err(Status::new(Code::Unknown, "Response channel closed")))
.await;
continue;
}
}
@@ -407,25 +438,25 @@ impl ShardedRequestTracker {
// Default configuration for the client. These could be added to a config file
//
let tcp_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 8, // Streams per connection
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 8, // Streams per connection
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 8,
};
let stream_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 64, // Requests per stream
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 64, // Requests per stream
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 64, // Total allowable number of streams
};
ShardedRequestTracker {
@@ -437,23 +468,26 @@ impl ShardedRequestTracker {
}
}
pub async fn update_shard_map(&self,
shard_urls: std::collections::HashMap<ShardIndex, String>,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
tenant_id: String, timeline_id: String, auth_str: Option<&str>) {
let mut trackers = std::collections::HashMap::new();
pub async fn update_shard_map(
&self,
shard_urls: std::collections::HashMap<ShardIndex, String>,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
tenant_id: String,
timeline_id: String,
auth_str: Option<&str>,
) {
let mut trackers = std::collections::HashMap::new();
for (shard, endpoint_url) in shard_urls {
//
// Create a pool of streams for streaming get_page requests
//
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
endpoint_url.clone(),
self.tcp_client_cache_options.max_delay_ms,
self.tcp_client_cache_options.drop_rate,
self.tcp_client_cache_options.hang_rate,
));
let channel_fact: Arc<dyn PooledItemFactory<Channel> + Send + Sync> =
Arc::new(ChannelFactory::new(
endpoint_url.clone(),
self.tcp_client_cache_options.max_delay_ms,
self.tcp_client_cache_options.drop_rate,
self.tcp_client_cache_options.hang_rate,
));
let new_pool: Arc<ConnectionPool<Channel>>;
new_pool = ConnectionPool::new(
Arc::clone(&channel_fact),
@@ -466,13 +500,15 @@ impl ShardedRequestTracker {
metrics.clone(),
);
let auth_interceptor = AuthInterceptor::new(tenant_id.as_str(),
timeline_id.as_str(),
auth_str);
let auth_interceptor =
AuthInterceptor::new(tenant_id.as_str(), timeline_id.as_str(), auth_str);
let stream_pool = ConnectionPool::<StreamReturner>::new(
Arc::new(StreamFactory::new(new_pool.clone(),
auth_interceptor.clone(), ShardIndex::unsharded())),
Arc::new(StreamFactory::new(
new_pool.clone(),
auth_interceptor.clone(),
ShardIndex::unsharded(),
)),
self.stream_client_cache_options.connect_timeout,
self.stream_client_cache_options.connect_backoff,
self.stream_client_cache_options.max_consumers,
@@ -495,7 +531,7 @@ impl ShardedRequestTracker {
self.tcp_client_cache_options.error_threshold,
self.tcp_client_cache_options.max_idle_duration,
self.tcp_client_cache_options.max_total_connections,
metrics.clone()
metrics.clone(),
);
//
// Create a new RequestTracker for this shard
@@ -507,11 +543,7 @@ impl ShardedRequestTracker {
inner.trackers = trackers;
}
pub async fn get_page(
&self,
req: GetPageRequest,
) -> Result<GetPageResponse, tonic::Status> {
pub async fn get_page(&self, req: GetPageRequest) -> Result<GetPageResponse, tonic::Status> {
// Get shard index from the request and look up the RequestTracker instance for that shard
let shard_index = ShardIndex::unsharded(); // TODO!
let mut tracker = self.lookup_tracker_for_shard(shard_index)?;

@@ -1,4 +1,4 @@
use std::collections::{HashSet, HashMap, VecDeque};
use std::collections::{HashMap, HashSet, VecDeque};
use std::future::Future;
use std::num::NonZeroUsize;
use std::pin::Pin;

@@ -93,11 +93,15 @@ pub(super) async fn init(
.worker_process_init(last_lsn, file_cache);
let request_tracker = ShardedRequestTracker::new();
request_tracker.update_shard_map(shard_map,
None,
tenant_id,
timeline_id,
auth_token.as_deref()).await;
request_tracker
.update_shard_map(
shard_map,
None,
tenant_id,
timeline_id,
auth_token.as_deref(),
)
.await;
let request_counters = IntCounterVec::new(
metrics::core::Opts::new(

@@ -5,7 +5,9 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use control_plane::endpoint::{ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo};
use control_plane::endpoint::{
ComputeControlPlane, EndpointStatus, PageserverConnectionInfo, PageserverShardConnectionInfo,
};
use control_plane::local_env::LocalEnv;
use futures::StreamExt;
use hyper::StatusCode;
@@ -438,8 +440,8 @@ impl ComputeHook {
format!("postgres://no_user@{host}:{port}")
});
let grpc_url = if let Some(grpc_addr) = &ps_conf.listen_grpc_addr {
let (host, port) = parse_host_port(grpc_addr)
.expect("invalid gRPC address");
let (host, port) =
parse_host_port(grpc_addr).expect("invalid gRPC address");
let port = port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT);
Some(format!("grpc://no_user@{host}:{port}"))
} else {