// examples/load_test.rs, generated by AI use std::collections::{HashMap, HashSet}; use std::sync::{ Arc, Mutex, atomic::{AtomicU64, AtomicUsize, Ordering}, }; use std::time::{Duration, Instant}; use rand::Rng; use tokio::task; use tokio::time::sleep; use tonic::Status; // Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate. // Adjust these paths if necessary. use pageserver_client_grpc::client_cache::ConnectionPool; use pageserver_client_grpc::client_cache::PooledItemFactory; // -------------------------------------- // GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections // -------------------------------------- static CREATED: AtomicU64 = AtomicU64::new(0); static DROPPED: AtomicU64 = AtomicU64::new(0); // -------------------------------------- // MockConnection + Factory // -------------------------------------- #[derive(Debug)] pub struct MockConnection { pub id: u64, } impl Clone for MockConnection { fn clone(&self) -> Self { // Cloning a MockConnection does NOT count as “creating” a brand‐new connection, // so we do NOT bump CREATED here. We only bump CREATED in the factory’s `create()`. CREATED.fetch_add(1, Ordering::Relaxed); MockConnection { id: self.id } } } impl Drop for MockConnection { fn drop(&mut self) { // When a MockConnection actually gets dropped, bump the counter. DROPPED.fetch_add(1, Ordering::SeqCst); } } #[derive(Default)] pub struct MockConnectionFactory { counter: AtomicU64, } #[async_trait::async_trait] impl PooledItemFactory for MockConnectionFactory { /// The trait on ConnectionPool expects: /// async fn create(&self, timeout: Duration) /// -> Result, tokio::time::error::Elapsed>; /// /// On success: Ok(Ok(MockConnection)) /// On a simulated “gRPC” failure: Ok(Err(Status::…)) /// On a transport/factory error: Err(Box<…>) async fn create( &self, _timeout: Duration, ) -> Result, tokio::time::error::Elapsed> { // Simulate connection creation immediately succeeding. CREATED.fetch_add(1, Ordering::SeqCst); let next_id = self.counter.fetch_add(1, Ordering::Relaxed); Ok(Ok(MockConnection { id: next_id })) } } // -------------------------------------- // CLIENT WORKER // -------------------------------------- // // Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we: // 1. Lock the shared Mutex>> to fetch/insert an Arc for this conn_id. // 2. Lock the shared Mutex> to record this conn_id as “seen.” // 3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers. // 4. Sleep 10–100 ms to simulate “work.” // 5. Atomically decrement the counter. // 6. Call `pooled.finish(Ok(()))` to return to the pool. async fn client_worker( pool: Arc>, usage_map: Arc>>>, seen_set: Arc>>, max_consumers: usize, worker_id: usize, ) { for iteration in 0..10 { match pool.clone().get_client().await { Ok(pooled) => { let conn: MockConnection = pooled.channel(); let conn_id = conn.id; // 1. Fetch or insert the Arc for this conn_id: let counter_arc: Arc = { let mut guard = usage_map.lock().unwrap(); guard .entry(conn_id) .or_insert_with(|| Arc::new(AtomicUsize::new(0))) .clone() // MutexGuard is dropped here }; // 2. Record this conn_id in the shared HashSet of “seen” IDs: { let mut seen_guard = seen_set.lock().unwrap(); seen_guard.insert(conn_id); // MutexGuard is dropped immediately } // 3. Atomically bump the count for this connection ID let prev = counter_arc.fetch_add(1, Ordering::SeqCst); let current = prev + 1; assert!( current <= max_consumers, "Connection {conn_id} exceeded max_consumers (got {current})", ); println!( "[worker {worker_id}][iter {iteration}] got MockConnection id={conn_id} ({current} concurrent)", ); // 4. Simulate some work (10–100 ms) let delay_ms = rand::thread_rng().gen_range(10..100); sleep(Duration::from_millis(delay_ms)).await; // 5. Decrement the usage counter let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst); let after = prev2 - 1; println!( "[worker {worker_id}][iter {iteration}] returning MockConnection id={conn_id} (now {after} remain)", ); // 6. Return to the pool (mark success) pooled.finish(Ok(())).await; } Err(status) => { eprintln!( "[worker {worker_id}][iter {iteration}] failed to get client: {status:?}", ); } } // Small random pause before next iteration to spread out load let pause = rand::thread_rng().gen_range(0..20); sleep(Duration::from_millis(pause)).await; } } #[tokio::main(flavor = "multi_thread", worker_threads = 8)] async fn main() { // -------------------------------------- // 1. Create factory and shared instrumentation // -------------------------------------- let factory = Arc::new(MockConnectionFactory::default()); // Shared map: connection ID → Arc let usage_map: Arc>>> = Arc::new(Mutex::new(HashMap::new())); // Shared set: record each unique connection ID we actually saw let seen_set: Arc>> = Arc::new(Mutex::new(HashSet::new())); // -------------------------------------- // 2. Pool parameters // -------------------------------------- let connect_timeout = Duration::from_millis(500); let connect_backoff = Duration::from_millis(100); let max_consumers = 100; // test limit let error_threshold = 2; // mock never fails let max_idle_duration = Duration::from_secs(2); let max_total_connections = 3; let aggregate_metrics = None; let pool: Arc> = ConnectionPool::new( factory, connect_timeout, connect_backoff, max_consumers, error_threshold, max_idle_duration, max_total_connections, aggregate_metrics, ); // -------------------------------------- // 3. Spawn worker tasks // -------------------------------------- let num_workers = 10000; let mut handles = Vec::with_capacity(num_workers); let start_time = Instant::now(); for worker_id in 0..num_workers { let pool_clone = Arc::clone(&pool); let usage_clone = Arc::clone(&usage_map); let seen_clone = Arc::clone(&seen_set); let mc = max_consumers; let handle = task::spawn(async move { client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await; }); handles.push(handle); } // -------------------------------------- // 4. Wait for workers to finish // -------------------------------------- for handle in handles { let _ = handle.await; } let elapsed = Instant::now().duration_since(start_time); println!("All {num_workers} workers completed in {elapsed:?}"); // -------------------------------------- // 5. Print the total number of unique connections seen so far // -------------------------------------- let unique_count = { let seen_guard = seen_set.lock().unwrap(); seen_guard.len() }; println!("Total unique connections used by workers: {unique_count}"); // -------------------------------------- // 6. Sleep so the background sweeper can run (max_idle_duration = 2 s) // -------------------------------------- sleep(Duration::from_secs(3)).await; // -------------------------------------- // 7. Shutdown the pool // -------------------------------------- let shutdown_pool = Arc::clone(&pool); shutdown_pool.shutdown().await; println!("Pool.shutdown() returned."); // -------------------------------------- // 8. Verify that no background task still holds an Arc clone of `pool`. // If any task is still alive (sweeper/create_connection), strong_count > 1. // -------------------------------------- sleep(Duration::from_secs(1)).await; // give tasks time to exit let sc = Arc::strong_count(&pool); assert!( sc == 1, "Pool tasks did not all terminate: Arc::strong_count = {sc} (expected 1)", ); println!("Verified: all pool tasks have terminated (strong_count == 1)."); // -------------------------------------- // 9. Verify no MockConnection was leaked: // CREATED must equal DROPPED. // -------------------------------------- let created = CREATED.load(Ordering::SeqCst); let dropped = DROPPED.load(Ordering::SeqCst); assert!( created == dropped, "Leaked connections: created={created} but dropped={dropped}", ); println!("Verified: no connections leaked (created = {created}, dropped = {dropped})."); // -------------------------------------- // 10. Because `client_worker` asserted inside that no connection // ever exceeded `max_consumers`, reaching this point means that check passed. // -------------------------------------- println!("All per-connection usage stayed within max_consumers = {max_consumers}."); println!("Load test complete; exiting cleanly."); }