diff --git a/Cargo.lock b/Cargo.lock index 9a6f7fe2ca..78cab72c10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4592,17 +4592,22 @@ dependencies = [ name = "pageserver_client_grpc" version = "0.1.0" dependencies = [ + "async-trait", "bytes", + "chrono", + "dashmap 5.5.0", "futures", "http 1.1.0", "hyper 1.6.0", "hyper-util", "metrics", + "pageserver_api", "pageserver_page_api", "priority-queue", "rand 0.8.5", "thiserror 1.0.69", "tokio", + "tokio-stream", "tokio-util", "tonic 0.13.1", "tower 0.4.13", diff --git a/pageserver/client_grpc/Cargo.toml b/pageserver/client_grpc/Cargo.toml index d0e162fbbe..fdd838c098 100644 --- a/pageserver/client_grpc/Cargo.toml +++ b/pageserver/client_grpc/Cargo.toml @@ -19,7 +19,12 @@ hyper-util = "0.1.9" hyper = "1.6.0" metrics.workspace = true priority-queue = "2.3.1" +async-trait = { version = "0.1" } +tokio-stream = "0.1" +dashmap = "5" +chrono = { version = "0.4", features = ["serde"] } pageserver_page_api.workspace = true +pageserver_api.workspace = true utils.workspace = true diff --git a/pageserver/client_grpc/examples/load_test.rs b/pageserver/client_grpc/examples/load_test.rs new file mode 100644 index 0000000000..75165a65b7 --- /dev/null +++ b/pageserver/client_grpc/examples/load_test.rs @@ -0,0 +1,296 @@ +// examples/load_test.rs, generated by AI + +use std::collections::{HashMap, HashSet}; +use std::sync::{ + Arc, + Mutex, + atomic::{AtomicU64, AtomicUsize, Ordering}, +}; +use std::time::{Duration, Instant}; + +use tokio::task; +use tokio::time::sleep; +use rand::Rng; +use tonic::Status; +use uuid::Uuid; + +// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate. +// Adjust these paths if necessary. +use pageserver_client_grpc::client_cache::ConnectionPool; +use pageserver_client_grpc::client_cache::PooledItemFactory; + +// -------------------------------------- +// GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections +// -------------------------------------- +static CREATED: AtomicU64 = AtomicU64::new(0); +static DROPPED: AtomicU64 = AtomicU64::new(0); + +// -------------------------------------- +// MockConnection + Factory +// -------------------------------------- + +#[derive(Debug)] +pub struct MockConnection { + pub id: u64, +} + +impl Clone for MockConnection { + fn clone(&self) -> Self { + // Cloning a MockConnection does NOT count as “creating” a brand‐new connection, + // so we do NOT bump CREATED here. We only bump CREATED in the factory’s `create()`. + CREATED.fetch_add(1, Ordering::Relaxed); + MockConnection { id: self.id } + } +} + +impl Drop for MockConnection { + fn drop(&mut self) { + // When a MockConnection actually gets dropped, bump the counter. + DROPPED.fetch_add(1, Ordering::SeqCst); + } +} + +pub struct MockConnectionFactory { + counter: AtomicU64, +} + +impl MockConnectionFactory { + pub fn new() -> Self { + MockConnectionFactory { + counter: AtomicU64::new(1), + } + } +} + +#[async_trait::async_trait] +impl PooledItemFactory for MockConnectionFactory { + /// The trait on ConnectionPool expects: + /// async fn create(&self, timeout: Duration) + /// -> Result, tokio::time::error::Elapsed>; + /// + /// On success: Ok(Ok(MockConnection)) + /// On a simulated “gRPC” failure: Ok(Err(Status::…)) + /// On a transport/factory error: Err(Box<…>) + async fn create( + &self, + _timeout: Duration, + ) -> Result, tokio::time::error::Elapsed> { + // Simulate connection creation immediately succeeding. 
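+ // Note (descriptive comment, not in the original): MockConnection::clone() also bumps CREATED, so the final leak check compares every construction (factory-created or cloned) against every Drop.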
+ CREATED.fetch_add(1, Ordering::SeqCst); + let next_id = self.counter.fetch_add(1, Ordering::Relaxed); + Ok(Ok(MockConnection { id: next_id })) + } +} + +// -------------------------------------- +// CLIENT WORKER +// -------------------------------------- +// +// Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we: +// 1. Lock the shared Mutex>> to fetch/insert an Arc for this conn_id. +// 2. Lock the shared Mutex> to record this conn_id as “seen.” +// 3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers. +// 4. Sleep 10–100 ms to simulate “work.” +// 5. Atomically decrement the counter. +// 6. Call `pooled.finish(Ok(()))` to return to the pool. + +async fn client_worker( + pool: Arc>, + usage_map: Arc>>>, + seen_set: Arc>>, + max_consumers: usize, + worker_id: usize, +) { + for iteration in 0..10 { + match pool.clone().get_client().await { + Ok(pooled) => { + let conn: MockConnection = pooled.channel(); + let conn_id = conn.id; + + // 1. Fetch or insert the Arc for this conn_id: + let counter_arc: Arc = { + let mut guard = usage_map.lock().unwrap(); + guard + .entry(conn_id) + .or_insert_with(|| Arc::new(AtomicUsize::new(0))) + .clone() + // MutexGuard is dropped here + }; + + // 2. Record this conn_id in the shared HashSet of “seen” IDs: + { + let mut seen_guard = seen_set.lock().unwrap(); + seen_guard.insert(conn_id); + // MutexGuard is dropped immediately + } + + // 3. Atomically bump the count for this connection ID + let prev = counter_arc.fetch_add(1, Ordering::SeqCst); + let current = prev + 1; + assert!( + current <= max_consumers, + "Connection {} exceeded max_consumers (got {})", + conn_id, + current + ); + + println!( + "[worker {}][iter {}] got MockConnection id={} ({} concurrent)", + worker_id, iteration, conn_id, current + ); + + // 4. Simulate some work (10–100 ms) + let delay_ms = rand::thread_rng().gen_range(10..100); + sleep(Duration::from_millis(delay_ms)).await; + + // 5. Decrement the usage counter + let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst); + let after = prev2 - 1; + println!( + "[worker {}][iter {}] returning MockConnection id={} (now {} remain)", + worker_id, iteration, conn_id, after + ); + + // 6. Return to the pool (mark success) + pooled.finish(Ok(())).await; + } + Err(status) => { + eprintln!( + "[worker {}][iter {}] failed to get client: {:?}", + worker_id, iteration, status + ); + } + } + + // Small random pause before next iteration to spread out load + let pause = rand::thread_rng().gen_range(0..20); + sleep(Duration::from_millis(pause)).await; + } +} + +#[tokio::main(flavor = "multi_thread", worker_threads = 8)] +async fn main() { + // -------------------------------------- + // 1. Create factory and shared instrumentation + // -------------------------------------- + let factory = Arc::new(MockConnectionFactory::new()); + + // Shared map: connection ID → Arc + let usage_map: Arc>>> = + Arc::new(Mutex::new(HashMap::new())); + + // Shared set: record each unique connection ID we actually saw + let seen_set: Arc>> = Arc::new(Mutex::new(HashSet::new())); + + // -------------------------------------- + // 2. 
Pool parameters + // -------------------------------------- + let connect_timeout = Duration::from_millis(500); + let connect_backoff = Duration::from_millis(100); + let max_consumers = 100; // test limit + let error_threshold = 2; // mock never fails + let max_idle_duration = Duration::from_secs(2); + let max_total_connections = 3; + let aggregate_metrics = None; + + let pool: Arc> = ConnectionPool::new( + factory, + connect_timeout, + connect_backoff, + max_consumers, + error_threshold, + max_idle_duration, + max_total_connections, + aggregate_metrics, + ); + + // -------------------------------------- + // 3. Spawn worker tasks + // -------------------------------------- + let num_workers = 10000; + let mut handles = Vec::with_capacity(num_workers); + let start_time = Instant::now(); + + for worker_id in 0..num_workers { + let pool_clone = Arc::clone(&pool); + let usage_clone = Arc::clone(&usage_map); + let seen_clone = Arc::clone(&seen_set); + let mc = max_consumers; + + let handle = task::spawn(async move { + client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await; + }); + handles.push(handle); + } + + // -------------------------------------- + // 4. Wait for workers to finish + // -------------------------------------- + for handle in handles { + let _ = handle.await; + } + let elapsed = Instant::now().duration_since(start_time); + println!( + "All {} workers completed in {:?}", + num_workers, elapsed + ); + + // -------------------------------------- + // 5. Print the total number of unique connections seen so far + // -------------------------------------- + let unique_count = { + let seen_guard = seen_set.lock().unwrap(); + seen_guard.len() + }; + println!("Total unique connections used by workers: {}", unique_count); + + // -------------------------------------- + // 6. Sleep so the background sweeper can run (max_idle_duration = 2 s) + // -------------------------------------- + sleep(Duration::from_secs(3)).await; + + // -------------------------------------- + // 7. Shutdown the pool + // -------------------------------------- + let shutdown_pool = Arc::clone(&pool); + shutdown_pool.shutdown().await; + println!("Pool.shutdown() returned."); + + // -------------------------------------- + // 8. Verify that no background task still holds an Arc clone of `pool`. + // If any task is still alive (sweeper/create_connection), strong_count > 1. + // -------------------------------------- + sleep(Duration::from_secs(1)).await; // give tasks time to exit + let sc = Arc::strong_count(&pool); + assert!( + sc == 1, + "Pool tasks did not all terminate: Arc::strong_count = {} (expected 1)", + sc + ); + println!("Verified: all pool tasks have terminated (strong_count == 1)."); + + // -------------------------------------- + // 9. Verify no MockConnection was leaked: + // CREATED must equal DROPPED. + // -------------------------------------- + let created = CREATED.load(Ordering::SeqCst); + let dropped = DROPPED.load(Ordering::SeqCst); + assert!( + created == dropped, + "Leaked connections: created={} but dropped={}", + created, + dropped + ); + println!( + "Verified: no connections leaked (created = {}, dropped = {}).", + created, dropped + ); + + // -------------------------------------- + // 10. Because `client_worker` asserted inside that no connection + // ever exceeded `max_consumers`, reaching this point means that check passed. 
+ // -------------------------------------- + println!("All per-connection usage stayed within max_consumers = {}.", max_consumers); + + println!("Load test complete; exiting cleanly."); +} diff --git a/pageserver/client_grpc/examples/request_tracker_load_test.rs b/pageserver/client_grpc/examples/request_tracker_load_test.rs new file mode 100644 index 0000000000..0e20e2acdd --- /dev/null +++ b/pageserver/client_grpc/examples/request_tracker_load_test.rs @@ -0,0 +1,160 @@ +// examples/request_tracker_load_test.rs + +use std::{sync::Arc, time::Duration}; +use tokio; +use pageserver_client_grpc::request_tracker::RequestTracker; +use pageserver_client_grpc::request_tracker::MockStreamFactory; +use pageserver_client_grpc::request_tracker::StreamReturner; +use pageserver_client_grpc::client_cache::ConnectionPool; +use pageserver_client_grpc::client_cache::PooledItemFactory; +use pageserver_client_grpc::ClientCacheOptions; +use pageserver_client_grpc::PageserverClientAggregateMetrics; +use pageserver_client_grpc::AuthInterceptor; + +use pageserver_client_grpc::client_cache::ChannelFactory; + +use tonic::{transport::{Channel}, Request}; + +use rand::prelude::*; + +use pageserver_api::key::Key; + +use utils::lsn::Lsn; +use utils::id::TenantTimelineId; + +use futures::stream::FuturesOrdered; +use futures::StreamExt; +// use chrono +use chrono::Utc; + +use pageserver_page_api::{GetPageClass, GetPageResponse}; +use pageserver_page_api::proto; +#[derive(Clone)] +struct KeyRange { + timeline: TenantTimelineId, + timeline_lsn: Lsn, + start: i128, + end: i128, +} + +impl KeyRange { + fn len(&self) -> i128 { + self.end - self.start + } +} + +#[tokio::main] +async fn main() { + // 1) configure the client‐pool behavior + let client_cache_options = ClientCacheOptions { + max_delay_ms: 0, + drop_rate: 0.0, + hang_rate: 0.0, + connect_timeout: Duration::from_secs(10), + connect_backoff: Duration::from_millis(200), + max_consumers: 64, + error_threshold: 10, + max_idle_duration: Duration::from_secs(60), + max_total_connections: 12, + }; + + // 2) metrics collector (we assume Default is implemented) + let metrics = Arc::new(PageserverClientAggregateMetrics::new()); + let pool = ConnectionPool::::new( + Arc::new(MockStreamFactory::new( + )), + client_cache_options.connect_timeout, + client_cache_options.connect_backoff, + client_cache_options.max_consumers, + client_cache_options.error_threshold, + client_cache_options.max_idle_duration, + client_cache_options.max_total_connections, + Some(Arc::clone(&metrics)), + ); + + // ----------- + // There is no mock for the unary connection pool, so for now just + // don't use this pool + // + let channel_fact : Arc + Send + Sync> = Arc::new(ChannelFactory::new( + "".to_string(), + client_cache_options.max_delay_ms, + client_cache_options.drop_rate, + client_cache_options.hang_rate, + )); + let unary_pool: Arc> = ConnectionPool::new( + Arc::clone(&channel_fact), + client_cache_options.connect_timeout, + client_cache_options.connect_backoff, + client_cache_options.max_consumers, + client_cache_options.error_threshold, + client_cache_options.max_idle_duration, + client_cache_options.max_total_connections, + Some(Arc::clone(&metrics)), + ); + + // ----------- + // Dummy auth interceptor. This is not used in this test. 
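+ // The tenant/timeline IDs below are dummy placeholders; AuthInterceptor::new() only needs them to parse as ASCII metadata values, and this mock-driven test never sends them to a real server.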
+ let auth_interceptor = AuthInterceptor::new("dummy_tenant_id", + "dummy_timeline_id", + None); + let mut tracker = RequestTracker::new( + pool, + unary_pool, + auth_interceptor, + ); + + // 4) fire off 10 000 requests in parallel + let mut handles = FuturesOrdered::new(); + for i in 0..500000 { + + let mut rng = rand::thread_rng(); + let r = 0..=1000000i128; + let key: i128 = rng.gen_range(r.clone()); + let key = Key::from_i128(key); + let (rel_tag, block_no) = key + .to_rel_block() + .expect("we filter non-rel-block keys out above"); + + let req2 = proto::GetPageRequest { + request_id: 0, + request_class: proto::GetPageClass::Normal as i32, + read_lsn: Some(proto::ReadLsn { + request_lsn: if rng.gen_bool(0.5) { + u64::from(Lsn::MAX) + } else { + 10000 + }, + not_modified_since_lsn: 10000, + }), + rel: Some(rel_tag.into()), + block_number: vec![block_no], + }; + let req_model = pageserver_page_api::GetPageRequest::try_from(req2.clone()); + + // RequestTracker is Clone, so we can share it + let mut tr = tracker.clone(); + let fut = async move { + let resp = tr.send_getpage_request(req_model.unwrap()).await.unwrap(); + // sanity‐check: the mock echo returns the same request_id + assert!(resp.request_id > 0); + }; + handles.push_back(fut); + + // empty future + let fut = async move {}; + fut.await; + } + + // print timestamp + println!("Starting 5000000 requests at: {}", chrono::Utc::now()); + // 5) wait for them all + for i in 0..500000 { + handles.next().await.expect("Failed to get next handle"); + } + + // print timestamp + println!("Finished 5000000 requests at: {}", chrono::Utc::now()); + + println!("✅ All 100000 requests completed successfully"); +} diff --git a/pageserver/client_grpc/src/client_cache.rs b/pageserver/client_grpc/src/client_cache.rs index b58a7119a4..89c2d2b44e 100644 --- a/pageserver/client_grpc/src/client_cache.rs +++ b/pageserver/client_grpc/src/client_cache.rs @@ -31,6 +31,7 @@ use hyper_util::rt::TokioIo; use tower::service_fn; use tokio_util::sync::CancellationToken; +use async_trait::async_trait; // // The "TokioTcp" is flakey TCP network for testing purposes, in order @@ -164,32 +165,132 @@ impl AsyncWrite for TokioTcp { } } -/// A pooled gRPC client with capacity tracking and error handling. -pub struct ConnectionPool { - inner: Mutex, +#[async_trait] +pub trait PooledItemFactory: Send + Sync + 'static { + /// Create a new pooled item. + async fn create(&self, connect_timeout: Duration) -> Result, tokio::time::error::Elapsed>; +} - // Config options that apply to each connection +pub struct ChannelFactory { endpoint: String, - max_consumers: usize, - error_threshold: usize, - connect_timeout: Duration, - connect_backoff: Duration, - - // Parameters for testing max_delay_ms: u64, drop_rate: f64, hang_rate: f64, +} - // The maximum duration a connection can be idle before being removed + +impl ChannelFactory { + pub fn new( + endpoint: String, + max_delay_ms: u64, + drop_rate: f64, + hang_rate: f64, + ) -> Self { + ChannelFactory { + endpoint, + max_delay_ms, + drop_rate, + hang_rate, + } + } +} + +#[async_trait] +impl PooledItemFactory for ChannelFactory { + async fn create(&self, connect_timeout: Duration) -> Result, tokio::time::error::Elapsed> { + let max_delay_ms = self.max_delay_ms; + let drop_rate = self.drop_rate; + let hang_rate = self.hang_rate; + + // This is a custom connector that inserts delays and errors, for + // testing purposes. It would normally be disabled by the config. 
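+ // Test knobs: drop_rate randomly fails the connect, hang_rate makes the connect never complete (exercising the timeout), and max_delay_ms is intended to add artificial latency to the wrapped TCP stream.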
+ let connector = service_fn(move |uri: Uri| { + let drop_rate = drop_rate; + let hang_rate = hang_rate; + async move { + let mut rng = StdRng::from_entropy(); + // Simulate an indefinite hang + if hang_rate > 0.0 && rng.gen_bool(hang_rate) { + // never completes, to test timeout + return future::pending::, std::io::Error>>().await; + } + + // Random drop (connect error) + if drop_rate > 0.0 && rng.gen_bool(drop_rate) { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "simulated connect drop", + )); + } + + // Otherwise perform real TCP connect + let addr = match (uri.host(), uri.port()) { + // host + explicit port + (Some(host), Some(port)) => format!("{}:{}", host, port.as_str()), + // host only (no port) + (Some(host), None) => host.to_string(), + // neither? error out + _ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")), + }; + + let tcp = TcpStream::connect(addr).await?; + let tcpwrapper = TokioTcp::new(tcp, max_delay_ms); + Ok(TokioIo::new(tcpwrapper)) + } + }); + + + let attempt = tokio::time::timeout( + connect_timeout, + Endpoint::from_shared(self.endpoint.clone()) + .expect("invalid endpoint") + .timeout(connect_timeout) + .connect_with_connector(connector), + ) + .await; + match attempt { + Ok(Ok(channel)) => { + // Connection succeeded + Ok(Ok(channel)) + } + Ok(Err(e)) => { + Ok(Err(tonic::Status::new( + tonic::Code::Unavailable, + format!("Failed to connect: {}", e), + ))) + } + Err(e) => { + Err(e) + } + } + } +} + + +/// A pooled gRPC client with capacity tracking and error handling. +pub struct ConnectionPool { + inner: Mutex>, + + fact: Arc + Send + Sync>, + + connect_timeout: Duration, + connect_backoff: Duration, + /// The maximum number of consumers that can use a single connection. + max_consumers: usize, + /// The number of consecutive errors before a connection is removed from the pool. + error_threshold: usize, + /// The maximum duration a connection can be idle before being removed. max_idle_duration: Duration, + max_total_connections: usize, + channel_semaphore: Arc, shutdown_token: CancellationToken, aggregate_metrics: Option>, } -struct Inner { - entries: HashMap, +struct Inner { + entries: HashMap>, pq: PriorityQueue, // This is updated when a connection is dropped, or we fail // to create a new connection. @@ -197,54 +298,50 @@ struct Inner { waiters: usize, in_progress: usize, } - -struct ConnectionEntry { - channel: Channel, +struct ConnectionEntry { + channel: T, active_consumers: usize, consecutive_errors: usize, last_used: Instant, } /// A client borrowed from the pool. 
-pub struct PooledClient { - pub channel: Channel, - pool: Arc, +pub struct PooledClient { + pub channel: T, + pool: Arc>, + is_ok: bool, id: uuid::Uuid, permit: OwnedSemaphorePermit, } -impl ConnectionPool { +impl ConnectionPool { pub fn new( - endpoint: &String, - max_consumers: usize, - error_threshold: usize, + fact: Arc + Send + Sync>, connect_timeout: Duration, connect_backoff: Duration, + max_consumers: usize, + error_threshold: usize, max_idle_duration: Duration, - max_delay_ms: u64, - drop_rate: f64, - hang_rate: f64, + max_total_connections: usize, aggregate_metrics: Option>, ) -> Arc { let shutdown_token = CancellationToken::new(); let pool = Arc::new(Self { - inner: Mutex::new(Inner { + inner: Mutex::new(Inner:: { entries: HashMap::new(), pq: PriorityQueue::new(), last_connect_failure: None, waiters: 0, in_progress: 0, }), + fact: Arc::clone(&fact), + connect_timeout, + connect_backoff, + max_consumers, + error_threshold, + max_idle_duration, + max_total_connections, channel_semaphore: Arc::new(Semaphore::new(0)), - endpoint: endpoint.clone(), - max_consumers: max_consumers, - error_threshold: error_threshold, - connect_timeout: connect_timeout, - connect_backoff: connect_backoff, - max_idle_duration: max_idle_duration, - max_delay_ms: max_delay_ms, - drop_rate: drop_rate, - hang_rate: hang_rate, shutdown_token: shutdown_token.clone(), aggregate_metrics: aggregate_metrics.clone(), }); @@ -325,7 +422,7 @@ impl ConnectionPool { async fn get_conn_with_permit( self: Arc, permit: OwnedSemaphorePermit, - ) -> Option { + ) -> Option> { let mut inner = self.inner.lock().await; // Pop the highest-active-consumers connection. There are no connections @@ -340,9 +437,10 @@ impl ConnectionPool { entry.active_consumers += 1; entry.last_used = Instant::now(); - let client = PooledClient { + let client = PooledClient:: { channel: entry.channel.clone(), pool: Arc::clone(&self), + is_ok: true, id, permit: permit, }; @@ -365,7 +463,7 @@ impl ConnectionPool { } } - pub async fn get_client(self: Arc) -> Result { + pub async fn get_client(self: Arc) -> Result, tonic::Status> { // The pool is shutting down. Don't accept new connections. if self.shutdown_token.is_cancelled() { return Err(tonic::Status::unavailable("Pool is shutting down")); @@ -412,12 +510,16 @@ impl ConnectionPool { // let mut inner = self_clone.inner.lock().await; inner.waiters += 1; - if inner.waiters >= (inner.in_progress * self_clone.max_consumers) { - let self_clone_spawn = Arc::clone(&self_clone); - tokio::task::spawn(async move { - self_clone_spawn.create_connection().await; - }); - inner.in_progress += 1; + if inner.waiters > (inner.in_progress * self_clone.max_consumers) { + if (inner.entries.len() + inner.in_progress) < self_clone.max_total_connections { + + let self_clone_spawn = Arc::clone(&self_clone); + tokio::task::spawn(async move { + self_clone_spawn.create_connection().await; + }); + inner.in_progress += 1; + } + } } // Wait for a connection to become available, either because it @@ -446,46 +548,6 @@ impl ConnectionPool { } async fn create_connection(&self) -> () { - let max_delay_ms = self.max_delay_ms; - let drop_rate = self.drop_rate; - let hang_rate = self.hang_rate; - - // This is a custom connector that inserts delays and errors, for - // testing purposes. It would normally be disabled by the config. 
- let connector = service_fn(move |uri: Uri| { - let drop_rate = drop_rate; - let hang_rate = hang_rate; - async move { - let mut rng = StdRng::from_entropy(); - // Simulate an indefinite hang - if hang_rate > 0.0 && rng.gen_bool(hang_rate) { - // never completes, to test timeout - return future::pending::, std::io::Error>>().await; - } - - // Random drop (connect error) - if drop_rate > 0.0 && rng.gen_bool(drop_rate) { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "simulated connect drop", - )); - } - - // Otherwise perform real TCP connect - let addr = match (uri.host(), uri.port()) { - // host + explicit port - (Some(host), Some(port)) => format!("{}:{}", host, port.as_str()), - // host only (no port) - (Some(host), None) => host.to_string(), - // neither? error out - _ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")), - }; - - let tcp = TcpStream::connect(addr).await?; - let tcpwrapper = TokioTcp::new(tcp, max_delay_ms); - Ok(TokioIo::new(tcpwrapper)) - } - }); // Generate a random backoff to add some jitter so that connections // don't all retry at the same time. @@ -533,14 +595,9 @@ impl ConnectionPool { None => {} } - let attempt = tokio::time::timeout( - self.connect_timeout, - Endpoint::from_shared(self.endpoint.clone()) - .expect("invalid endpoint") - .timeout(self.connect_timeout) - .connect_with_connector(connector), - ) - .await; + let attempt = self.fact + .create(self.connect_timeout) + .await; match attempt { // Connection succeeded @@ -559,7 +616,7 @@ impl ConnectionPool { let id = uuid::Uuid::new_v4(); inner.entries.insert( id, - ConnectionEntry { + ConnectionEntry:: { channel: channel.clone(), active_consumers: 0, consecutive_errors: 0, @@ -641,6 +698,11 @@ impl ConnectionPool { inner.pq.remove(&id); } + // remove from entries + // check if entry is in inner + if inner.entries.contains_key(&id) { + inner.entries.remove(&id); + } inner.last_connect_failure = Some(Instant::now()); // The connection has been removed, it's permits will be @@ -661,18 +723,19 @@ impl ConnectionPool { } } } - // The semaphore permit is released when the pooled client is dropped. 
} } -impl PooledClient { - pub fn channel(&self) -> Channel { +impl PooledClient { + pub fn channel(&self) -> T { return self.channel.clone(); } - - pub async fn finish(self, result: Result<(), tonic::Status>) { - self.pool - .return_client(self.id, result.is_ok(), self.permit) - .await; + pub async fn finish(mut self, result: Result<(), tonic::Status>) { + self.is_ok = result.is_ok(); + self.pool.return_client( + self.id, + self.is_ok, + self.permit, + ).await; } } diff --git a/pageserver/client_grpc/src/lib.rs b/pageserver/client_grpc/src/lib.rs index 2947754817..e709068beb 100644 --- a/pageserver/client_grpc/src/lib.rs +++ b/pageserver/client_grpc/src/lib.rs @@ -20,9 +20,16 @@ use pageserver_page_api::proto::PageServiceClient; use utils::shard::ShardIndex; use std::fmt::Debug; -mod client_cache; +pub mod client_cache; +pub mod request_tracker; +use tonic::transport::Channel; use metrics::{IntCounterVec, core::Collector}; +use crate::client_cache::{PooledItemFactory}; + +use tokio::sync::mpsc; +use async_trait::async_trait; + #[derive(Error, Debug)] pub enum PageserverClientError { @@ -77,6 +84,7 @@ impl PageserverClientAggregateMetrics { metrics } } + pub struct PageserverClient { _tenant_id: String, _timeline_id: String, @@ -85,7 +93,7 @@ pub struct PageserverClient { shard_map: HashMap, - channels: RwLock>>, + channels: RwLock>>>, auth_interceptor: AuthInterceptor, @@ -93,13 +101,14 @@ pub struct PageserverClient { aggregate_metrics: Option>, } - +#[derive(Clone)] pub struct ClientCacheOptions { pub max_consumers: usize, pub error_threshold: usize, pub connect_timeout: Duration, pub connect_backoff: Duration, pub max_idle_duration: Duration, + pub max_total_connections: usize, pub max_delay_ms: u64, pub drop_rate: f64, pub hang_rate: f64, @@ -119,6 +128,7 @@ impl PageserverClient { connect_timeout: Duration::from_secs(5), connect_backoff: Duration::from_secs(1), max_idle_duration: Duration::from_secs(60), + max_total_connections: 100000, max_delay_ms: 0, drop_rate: 0.0, hang_rate: 0.0, @@ -349,13 +359,13 @@ impl PageserverClient { /// /// Get a client from the pool for this shard, also creating the pool if it doesn't exist. 
/// - async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient { - let reused_pool: Option> = { + async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient { + let reused_pool: Option>> = { let channels = self.channels.read().unwrap(); channels.get(&shard).cloned() }; - let usable_pool: Arc; + let usable_pool: Arc>; match reused_pool { Some(pool) => { let pooled_client = pool.get_client().await.unwrap(); @@ -365,17 +375,21 @@ impl PageserverClient { // Create a new pool using client_cache_options // declare new_pool - let new_pool: Arc; - new_pool = client_cache::ConnectionPool::new( - self.shard_map.get(&shard).unwrap(), - self.client_cache_options.max_consumers, - self.client_cache_options.error_threshold, - self.client_cache_options.connect_timeout, - self.client_cache_options.connect_backoff, - self.client_cache_options.max_idle_duration, + let new_pool: Arc>; + let channel_fact = Arc::new(client_cache::ChannelFactory::new( + self.shard_map.get(&shard).unwrap().clone(), self.client_cache_options.max_delay_ms, self.client_cache_options.drop_rate, self.client_cache_options.hang_rate, + )); + new_pool = client_cache::ConnectionPool::new( + channel_fact, + self.client_cache_options.connect_timeout, + self.client_cache_options.connect_backoff, + self.client_cache_options.max_consumers, + self.client_cache_options.error_threshold, + self.client_cache_options.max_idle_duration, + self.client_cache_options.max_total_connections, self.aggregate_metrics.clone(), ); let mut write_pool = self.channels.write().unwrap(); @@ -391,7 +405,7 @@ impl PageserverClient { /// Inject tenant_id, timeline_id and authentication token to all pageserver requests. #[derive(Clone)] -struct AuthInterceptor { +pub struct AuthInterceptor { tenant_id: AsciiMetadataValue, shard_id: Option, timeline_id: AsciiMetadataValue, @@ -400,7 +414,7 @@ struct AuthInterceptor { } impl AuthInterceptor { - fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self { + pub fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self { Self { tenant_id: tenant_id.parse().expect("could not parse tenant id"), shard_id: None, diff --git a/pageserver/client_grpc/src/request_tracker.rs b/pageserver/client_grpc/src/request_tracker.rs new file mode 100644 index 0000000000..118c455537 --- /dev/null +++ b/pageserver/client_grpc/src/request_tracker.rs @@ -0,0 +1,590 @@ + +// +// API Visible to the spawner, just a function call that is async +// +use std::sync::Arc; +use crate::client_cache; +use pageserver_page_api::GetPageRequest; +use pageserver_page_api::GetPageResponse; +use pageserver_page_api::*; +use pageserver_page_api::proto; +use crate::client_cache::ConnectionPool; +use crate::client_cache::ChannelFactory; +use crate::AuthInterceptor; +use tonic::{transport::{Channel}, Request}; +use crate::ClientCacheOptions; +use crate::PageserverClientAggregateMetrics; +use tokio::sync::Mutex; +use std::sync::atomic::{AtomicU64, Ordering}; + +use utils::shard::ShardIndex; + +use tokio_stream::wrappers::ReceiverStream; +use pageserver_page_api::proto::PageServiceClient; + +use tonic::{ + Status, + Code, +}; + +use async_trait::async_trait; +use std::time::Duration; + +use client_cache::PooledItemFactory; +//use tracing::info; +// +// A mock stream pool that just returns a sending channel, and whenever a GetPageRequest +// comes in on that channel, it randomly sleeps before sending a GetPageResponse +// + +#[derive(Clone)] +pub struct StreamReturner { + sender: 
tokio::sync::mpsc::Sender, + sender_hashmap: Arc>>>>, +} +pub struct MockStreamFactory { +} + +impl MockStreamFactory { + pub fn new() -> Self { + MockStreamFactory { + } + } +} +#[async_trait] +impl PooledItemFactory for MockStreamFactory { + async fn create(&self, _connect_timeout: Duration) -> Result, tokio::time::error::Elapsed> { + let (sender, mut receiver) = tokio::sync::mpsc::channel::(1000); + // Create a StreamReturner that will send requests to the receiver channel + let stream_returner = StreamReturner { + sender: sender.clone(), + sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())), + }; + + let map : Arc>>>> + = Arc::clone(&stream_returner.sender_hashmap); + tokio::spawn(async move { + while let Some(request) = receiver.recv().await { + + // Break out of the loop with 1% chance + if rand::random::() < 0.001 { + break; + } + // Generate a random number between 0 and 100 + // Simulate some processing time + let mapclone = Arc::clone(&map); + tokio::spawn(async move { + let sleep_ms = rand::random::() % 100; + tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await; + let response = proto::GetPageResponse { + request_id: request.request_id, + ..Default::default() + }; + // look up stream in hash map + let mut hashmap = mapclone.lock().await; + if let Some(sender) = hashmap.get(&request.request_id) { + // Send the response to the original request sender + if let Err(e) = sender.send(Ok(response.clone())).await { + eprintln!("Failed to send response: {}", e); + } + hashmap.remove(&request.request_id); + } else { + eprintln!("No sender found for request ID: {}", request.request_id); + } + }); + } + // Close every sender stream in the hashmap + let hashmap = map.lock().await; + for sender in hashmap.values() { + let error = Status::new(Code::Unknown, "Stream closed"); + if let Err(e) = sender.send(Err(error)).await { + eprintln!("Failed to send close response: {}", e); + } + } + }); + + Ok(Ok(stream_returner)) + } +} + + +pub struct StreamFactory { + connection_pool: Arc>, + auth_interceptor: AuthInterceptor, + shard: ShardIndex, +} + +impl StreamFactory { + pub fn new( + connection_pool: Arc>, + auth_interceptor: AuthInterceptor, + shard: ShardIndex, + ) -> Self { + StreamFactory { + connection_pool, + auth_interceptor, + shard, + } + } +} + +#[async_trait] +impl PooledItemFactory for StreamFactory { + async fn create(&self, _connect_timeout: Duration) -> + Result, tokio::time::error::Elapsed> + { + let pool_clone : Arc> = Arc::clone(&self.connection_pool); + let pooled_client = pool_clone.get_client().await; + let channel = pooled_client.unwrap().channel(); + let mut client = + PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + + let (sender, receiver) = tokio::sync::mpsc::channel::(1000); + let outbound = ReceiverStream::new(receiver); + + let client_resp = client + .get_pages(Request::new(outbound)) + .await; + + match client_resp { + Err(status) => { + // TODO: Convert this error correctly + Ok(Err(tonic::Status::new( + status.code(), + format!("Failed to connect to pageserver: {}", status.message()), + ))) + } + Ok(resp) => { + let stream_returner = StreamReturner { + sender: sender.clone(), + sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())), + }; + let map : Arc>>>> + = Arc::clone(&stream_returner.sender_hashmap); + + tokio::spawn(async move { + + let map_clone = Arc::clone(&map); + let mut inner = resp.into_inner(); + loop { + + let resp = inner.message().await; + if 
!resp.is_ok() { + break; // Exit the loop if no more messages + } + let response = resp.unwrap().unwrap(); + + // look up stream in hash map + let mut hashmap = map_clone.lock().await; + if let Some(sender) = hashmap.get(&response.request_id) { + // Send the response to the original request sender + if let Err(e) = sender.send(Ok(response.clone())).await { + eprintln!("Failed to send response: {}", e); + } + hashmap.remove(&response.request_id); + } else { + eprintln!("No sender found for request ID: {}", response.request_id); + } + } + // Close every sender stream in the hashmap + let hashmap = map_clone.lock().await; + for sender in hashmap.values() { + let error = Status::new(Code::Unknown, "Stream closed"); + if let Err(e) = sender.send(Err(error)).await { + eprintln!("Failed to send close response: {}", e); + } + } + }); + + Ok(Ok(stream_returner)) + } + } + } +} + +#[derive(Clone)] +pub struct RequestTracker { + cur_id: Arc, + stream_pool: Arc>, + unary_pool: Arc>, + auth_interceptor: AuthInterceptor, + shard: ShardIndex, +} + +impl RequestTracker { + pub fn new(stream_pool: Arc>, + unary_pool: Arc>, + auth_interceptor: AuthInterceptor, + shard: ShardIndex, + ) -> Self { + let cur_id = Arc::new(AtomicU64::new(0)); + + RequestTracker { + cur_id: cur_id.clone(), + stream_pool: stream_pool, + unary_pool: unary_pool, + auth_interceptor: auth_interceptor, + shard: shard.clone() + } + } + + pub async fn send_process_check_rel_exists_request( + &self, + req: CheckRelExistsRequest, + ) -> Result { + loop { + let unary_pool = Arc::clone(&self.unary_pool); + let pooled_client = unary_pool.get_client().await.unwrap(); + let channel = pooled_client.channel(); + let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + let request = proto::CheckRelExistsRequest::from(req.clone()); + let response = ps_client.check_rel_exists(tonic::Request::new(request)).await; + + match response { + Err(status) => { + pooled_client.finish(Err(status.clone())).await; // Pass error to finish + continue; + } + Ok(resp) => { + pooled_client.finish(Ok(())).await; // Pass success to finish + return Ok(resp.get_ref().exists); + } + } + } + } + + pub async fn send_process_get_rel_size_request( + &self, + req: GetRelSizeRequest, + ) -> Result { + loop { + // Current sharding model assumes that all metadata is present only at shard 0. + let unary_pool = Arc::clone(&self.unary_pool); + let pooled_client = unary_pool.get_client().await.unwrap(); + let channel = pooled_client.channel(); + let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + + let request = proto::GetRelSizeRequest::from(req.clone()); + let response = ps_client.get_rel_size(tonic::Request::new(request)).await; + + match response { + Err(status) => { + pooled_client.finish(Err(status.clone())).await; // Pass error to finish + continue; + } + Ok(resp) => { + pooled_client.finish(Ok(())).await; // Pass success to finish + return Ok(resp.get_ref().num_blocks); + } + } + + } + } + + pub async fn send_process_get_dbsize_request( + &self, + req: GetDbSizeRequest, + ) -> Result { + loop { + // Current sharding model assumes that all metadata is present only at shard 0. 
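+ // As with the other unary helpers above, this retries indefinitely: on error the client is returned with finish(Err(..)) so the pool can track consecutive failures, and the loop tries again.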
+ let unary_pool = Arc::clone(&self.unary_pool); + let pooled_client = unary_pool.get_client().await.unwrap();let channel = pooled_client.channel(); + let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard)); + + let request = proto::GetDbSizeRequest::from(req.clone()); + let response = ps_client.get_db_size(tonic::Request::new(request)).await; + + match response { + Err(status) => { + pooled_client.finish(Err(status.clone())).await; // Pass error to finish + continue; + } + Ok(resp) => { + pooled_client.finish(Ok(())).await; // Pass success to finish + return Ok(resp.get_ref().num_bytes); + } + } + + } + } + + pub async fn send_getpage_request( + &mut self, + req: GetPageRequest, + ) -> Result { + loop { + let mut request = req.clone(); + // Increment cur_id + //let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1; + let request_id = request.request_id; + let response_sender: tokio::sync::mpsc::Sender>; + let mut response_receiver: tokio::sync::mpsc::Receiver>; + + (response_sender, response_receiver) = tokio::sync::mpsc::channel(1); + //request.request_id = request_id; + + // Get a stream from the stream pool + let pool_clone = Arc::clone(&self.stream_pool); + let sender_stream_pool = pool_clone.get_client().await; + let stream_returner = match sender_stream_pool { + Ok(stream_ret) => stream_ret, + Err(_e) => { + // retry + continue; + } + }; + let returner = stream_returner.channel(); + let map = returner.sender_hashmap.clone(); + // Insert the response sender into the hashmap + { + let mut map_inner = map.lock().await; + map_inner.insert(request_id, response_sender); + } + let sent = returner.sender.send(proto::GetPageRequest::from(request)) + .await; + + if let Err(_e) = sent { + // Remove the request from the map if sending failed + { + let mut map_inner = map.lock().await; + // remove from hashmap + map_inner.remove(&request_id); + } + stream_returner.finish(Err(Status::new(Code::Unknown, + "Failed to send request"))).await; + continue; + } + + let response: Option>; + response = response_receiver.recv().await; + match response { + Some (resp) => { + match resp { + Err(_status) => { + // Handle the case where the response was not received + stream_returner.finish(Err(Status::new(Code::Unknown, + "Failed to receive response"))).await; + continue; + }, + Ok(resp) => { + stream_returner.finish(Result::Ok(())).await; + return Ok(resp.clone().into()); + } + } + } + None => { + // Handle the case where the response channel was closed + stream_returner.finish(Err(Status::new(Code::Unknown, + "Response channel closed"))).await; + continue; + } + } + } + } +} + +struct ShardedRequestTrackerInner { + // Hashmap of shard index to RequestTracker + trackers: std::collections::HashMap, +} +pub struct ShardedRequestTracker { + inner: Arc>, + tcp_client_cache_options: ClientCacheOptions, + stream_client_cache_options: ClientCacheOptions, +} + +// +// TODO: Functions in the ShardedRequestTracker should be able to timeout and +// cancel a reqeust. The request should return an error if it is cancelled. +// +impl ShardedRequestTracker { + pub fn new() -> Self { + // + // Default configuration for the client. 
These could be added to a config file + // + let tcp_client_cache_options = ClientCacheOptions { + max_delay_ms: 0, + drop_rate: 0.0, + hang_rate: 0.0, + connect_timeout: Duration::from_secs(1), + connect_backoff: Duration::from_millis(100), + max_consumers: 8, // Streams per connection + error_threshold: 10, + max_idle_duration: Duration::from_secs(5), + max_total_connections: 8, + }; + let stream_client_cache_options = ClientCacheOptions { + max_delay_ms: 0, + drop_rate: 0.0, + hang_rate: 0.0, + connect_timeout: Duration::from_secs(1), + connect_backoff: Duration::from_millis(100), + max_consumers: 64, // Requests per stream + error_threshold: 10, + max_idle_duration: Duration::from_secs(5), + max_total_connections: 64, // Total allowable number of streams + }; + ShardedRequestTracker { + inner: Arc::new(Mutex::new(ShardedRequestTrackerInner { + trackers: std::collections::HashMap::new(), + })), + tcp_client_cache_options, + stream_client_cache_options, + } + } + + pub async fn update_shard_map(&self, + shard_urls: std::collections::HashMap, + metrics: Option>, + tenant_id: String, timeline_id: String, auth_str: Option<&str>) { + + + let mut trackers = std::collections::HashMap::new(); + for (shard, endpoint_url) in shard_urls { + // + // Create a pool of streams for streaming get_page requests + // + let channel_fact : Arc + Send + Sync> = Arc::new(ChannelFactory::new( + endpoint_url.clone(), + self.tcp_client_cache_options.max_delay_ms, + self.tcp_client_cache_options.drop_rate, + self.tcp_client_cache_options.hang_rate, + )); + let new_pool: Arc>; + new_pool = ConnectionPool::new( + Arc::clone(&channel_fact), + self.tcp_client_cache_options.connect_timeout, + self.tcp_client_cache_options.connect_backoff, + self.tcp_client_cache_options.max_consumers, + self.tcp_client_cache_options.error_threshold, + self.tcp_client_cache_options.max_idle_duration, + self.tcp_client_cache_options.max_total_connections, + metrics.clone(), + ); + + let auth_interceptor = AuthInterceptor::new(tenant_id.as_str(), + timeline_id.as_str(), + auth_str); + + let stream_pool = ConnectionPool::::new( + Arc::new(StreamFactory::new(new_pool.clone(), + auth_interceptor.clone(), ShardIndex::unsharded())), + self.stream_client_cache_options.connect_timeout, + self.stream_client_cache_options.connect_backoff, + self.stream_client_cache_options.max_consumers, + self.stream_client_cache_options.error_threshold, + self.stream_client_cache_options.max_idle_duration, + self.stream_client_cache_options.max_total_connections, + metrics.clone(), + ); + + // + // Create a client pool for unary requests + // + + let unary_pool: Arc>; + unary_pool = ConnectionPool::new( + Arc::clone(&channel_fact), + self.tcp_client_cache_options.connect_timeout, + self.tcp_client_cache_options.connect_backoff, + self.tcp_client_cache_options.max_consumers, + self.tcp_client_cache_options.error_threshold, + self.tcp_client_cache_options.max_idle_duration, + self.tcp_client_cache_options.max_total_connections, + metrics.clone() + ); + // + // Create a new RequestTracker for this shard + // + let new_tracker = RequestTracker::new(stream_pool, unary_pool, auth_interceptor, shard); + trackers.insert(shard, new_tracker); + } + let mut inner = self.inner.lock().await; + inner.trackers = trackers; + } + + pub async fn get_page( + &self, + req: GetPageRequest, + ) -> Result { + + // Get shard index from the request + let shard_index = ShardIndex::unsharded(); + let inner = self.inner.lock().await; + let mut tracker : RequestTracker; + if let Some(t) = 
inner.trackers.get(&shard_index) { + tracker = t.clone(); + } else { + return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))); + } + drop(inner); + // Call the send_getpage_request method on the tracker + let response = tracker.send_getpage_request(req).await; + match response { + Ok(resp) => Ok(resp), + Err(e) => Err(tonic::Status::unknown(format!("Failed to get page: {}", e))), + } + } + pub async fn process_get_dbsize_request( + &self, + request: GetDbSizeRequest, + ) -> Result { + let shard_index = ShardIndex::unsharded(); + let inner = self.inner.lock().await; + let mut tracker: RequestTracker; + if let Some(t) = inner.trackers.get(&shard_index) { + tracker = t.clone(); + } else { + return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))); + } + drop(inner); // Release the lock before calling send_process_get_dbsize_request + // Call the send_process_get_dbsize_request method on the tracker + let response = tracker.send_process_get_dbsize_request(request).await; + match response { + Ok(resp) => Ok(resp), + Err(e) => Err(e), + } + } + + pub async fn process_get_rel_size_request( + &self, + request: GetRelSizeRequest, + ) -> Result { + let shard_index = ShardIndex::unsharded(); + let inner = self.inner.lock().await; + let mut tracker: RequestTracker; + if let Some(t) = inner.trackers.get(&shard_index) { + tracker = t.clone(); + } else { + return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))); + } + drop(inner); // Release the lock before calling send_process_get_rel_size_request + // Call the send_process_get_rel_size_request method on the tracker + let response = tracker.send_process_get_rel_size_request(request).await; + match response { + Ok(resp) => Ok(resp), + Err(e) => Err(e), + } + } + + pub async fn process_check_rel_exists_request( + &self, + request: CheckRelExistsRequest, + ) -> Result { + let shard_index = ShardIndex::unsharded(); + let inner = self.inner.lock().await; + let mut tracker: RequestTracker; + if let Some(t) = inner.trackers.get(&shard_index) { + tracker = t.clone(); + } else { + return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))); + } + drop(inner); // Release the lock before calling send_process_check_rel_exists_request + // Call the send_process_check_rel_exists_request method on the tracker + let response = tracker.send_process_check_rel_exists_request(request).await; + match response { + Ok(resp) => Ok(resp), + Err(e) => Err(e), + } + } +} diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index 1a08d04cc1..b3eeaece22 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -487,6 +487,7 @@ impl From for i32 { // Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other // shards will error. 
+#[derive(Clone)] pub struct GetRelSizeRequest { pub read_lsn: ReadLsn, pub rel: RelTag, diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index eebf618fce..6a25d18809 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -1,10 +1,11 @@ -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashSet, HashMap, VecDeque}; use std::future::Future; use std::num::NonZeroUsize; use std::pin::Pin; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; +use std::io::Error; use anyhow::Context; use async_trait::async_trait; @@ -23,6 +24,8 @@ use tracing::info; use utils::id::TenantTimelineId; use utils::lsn::Lsn; +use tonic::transport::Channel; + use axum::Router; use axum::body::Body; use axum::extract::State; @@ -427,6 +430,7 @@ async fn main_impl( .await .unwrap(), ), + }; run_worker(args, client, ss, cancel, rps_period, ranges, weights).await }) @@ -694,6 +698,7 @@ impl Client for LibpqClient { struct GrpcClient { req_tx: tokio::sync::mpsc::Sender, resp_rx: tonic::Streaming, + start_times: Vec, } impl GrpcClient { @@ -717,6 +722,7 @@ impl GrpcClient { Ok(Self { req_tx, resp_rx: resp_stream, + start_times: Vec::new(), }) } } @@ -741,6 +747,7 @@ impl Client for GrpcClient { rel: Some(rel.into()), block_number: blks, }; + self.start_times.push(Instant::now()); self.req_tx.send(req).await?; Ok(()) } @@ -755,3 +762,4 @@ impl Client for GrpcClient { Ok((resp.request_id, resp.page_image)) } } + diff --git a/pgxn/neon/communicator/src/worker_process/main_loop.rs b/pgxn/neon/communicator/src/worker_process/main_loop.rs index c6ce6c4197..e190193ae5 100644 --- a/pgxn/neon/communicator/src/worker_process/main_loop.rs +++ b/pgxn/neon/communicator/src/worker_process/main_loop.rs @@ -12,7 +12,7 @@ use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess}; use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest}; use crate::neon_request::{NeonIORequest, NeonIOResult}; use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable}; -use pageserver_client_grpc::PageserverClient; +use pageserver_client_grpc::request_tracker::ShardedRequestTracker; use pageserver_page_api as page_api; use metrics::{IntCounter, IntCounterVec}; @@ -30,7 +30,7 @@ use utils::lsn::Lsn; pub struct CommunicatorWorkerProcessStruct<'a> { neon_request_slots: &'a [NeonIOHandle], - pageserver_client: PageserverClient, + request_tracker: ShardedRequestTracker, pub(crate) cache: IntegratedCacheWriteAccess<'a>, @@ -74,6 +74,7 @@ pub(super) async fn init( initial_file_cache_size: u64, file_cache_path: Option, ) -> CommunicatorWorkerProcessStruct<'static> { + info!("Test log message"); let last_lsn = get_request_lsn(); let file_cache = if let Some(path) = file_cache_path { @@ -97,7 +98,12 @@ pub(super) async fn init( .integrated_cache_init_struct .worker_process_init(last_lsn, file_cache); - let pageserver_client = PageserverClient::new(&tenant_id, &timeline_id, &auth_token, shard_map); + let mut request_tracker = ShardedRequestTracker::new(); + request_tracker.update_shard_map(shard_map, + None, + tenant_id, + timeline_id, + auth_token.as_deref()).await; let request_counters = IntCounterVec::new( metrics::core::Opts::new( @@ -148,7 +154,7 @@ pub(super) async fn init( CommunicatorWorkerProcessStruct { neon_request_slots: cis.neon_request_slots, - pageserver_client, + request_tracker, cache, 
submission_pipe_read_fd: cis.submission_pipe_read_fd, next_request_id: AtomicU64::new(1), @@ -257,7 +263,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { }; match self - .pageserver_client + .request_tracker .process_check_rel_exists_request(page_api::CheckRelExistsRequest { read_lsn: self.request_lsns(not_modified_since), rel, @@ -291,7 +297,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { let read_lsn = self.request_lsns(not_modified_since); match self - .pageserver_client + .request_tracker .process_get_rel_size_request(page_api::GetRelSizeRequest { read_lsn, rel: rel.clone(), @@ -344,7 +350,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { }; match self - .pageserver_client + .request_tracker .process_get_dbsize_request(page_api::GetDbSizeRequest { read_lsn: self.request_lsns(not_modified_since), db_oid: req.db_oid, @@ -467,7 +473,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { // TODO: Use batched protocol for (blkno, _lsn, dest, _guard) in cache_misses.iter() { match self - .pageserver_client + .request_tracker .get_page(page_api::GetPageRequest { request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed), request_class: page_api::GetPageClass::Normal, @@ -477,11 +483,11 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { }) .await { - Ok(page_images) => { + Ok(resp) => { // Write the received page image directly to the shared memory location // that the backend requested. - assert!(page_images.len() == 1); - let page_image = page_images[0].clone(); + assert!(resp.page_images.len() == 1); + let page_image = resp.page_images[0].clone(); let src: &[u8] = page_image.as_ref(); let len = std::cmp::min(src.len(), dest.bytes_total() as usize); unsafe { @@ -545,7 +551,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { // TODO: Use batched protocol for (blkno, _lsn, _guard) in cache_misses.iter() { match self - .pageserver_client + .request_tracker .get_page(page_api::GetPageRequest { request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed), request_class: page_api::GetPageClass::Prefetch, @@ -555,13 +561,13 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { }) .await { - Ok(page_images) => { + Ok(resp) => { trace!( "prefetch completed, remembering blk {} in rel {:?} in LFC", *blkno, rel ); - assert!(page_images.len() == 1); - let page_image = page_images[0].clone(); + assert!(resp.page_images.len() == 1); + let page_image = resp.page_images[0].clone(); self.cache .remember_page(&rel, *blkno, page_image, not_modified_since, false) .await;