// Mirror of https://github.com/neondatabase/neon.git
// Synced 2026-01-08 14:02:55 +00:00 · 295 lines · 10 KiB · Rust
// examples/load_test.rs, generated by AI
|
||
|
||
use std::collections::{HashMap, HashSet};
|
||
use std::sync::{
|
||
Arc, Mutex,
|
||
atomic::{AtomicU64, AtomicUsize, Ordering},
|
||
};
|
||
use std::time::{Duration, Instant};
|
||
|
||
use rand::Rng;
|
||
use tokio::task;
|
||
use tokio::time::sleep;
|
||
use tonic::Status;
|
||
|
||
// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate.
|
||
// Adjust these paths if necessary.
|
||
use pageserver_client_grpc::client_cache::ConnectionPool;
|
||
use pageserver_client_grpc::client_cache::PooledItemFactory;
|
||
|
||
// --------------------------------------
|
||
// GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections
|
||
// --------------------------------------
|
||
// Global instrumentation: every `MockConnection` value that comes into
// existence (factory-made or cloned) bumps `CREATED`, and every one that is
// destroyed bumps `DROPPED`. `main` asserts the two are equal at the end of
// the run to detect leaked connections.
static CREATED: AtomicU64 = AtomicU64::new(0);
static DROPPED: AtomicU64 = AtomicU64::new(0);

// --------------------------------------
// MockConnection + Factory
// --------------------------------------

/// A stand-in for a real pooled connection, identified only by `id`.
///
/// Every instance is counted: factory creation and every `clone()` increment
/// [`CREATED`], and every drop increments [`DROPPED`].
#[derive(Debug)]
pub struct MockConnection {
    /// Identifier assigned by the factory; clones share the same id.
    pub id: u64,
}

impl Clone for MockConnection {
    /// Produces another `MockConnection` carrying the same `id`.
    ///
    /// The clone is a distinct value with its own `Drop`, which will bump
    /// `DROPPED` when it goes away. It therefore MUST also bump `CREATED`.
    /// Otherwise the `CREATED == DROPPED` leak check in `main` could never
    /// hold once any clone had been made.
    fn clone(&self) -> Self {
        // SeqCst to match every other access to CREATED/DROPPED in this file.
        CREATED.fetch_add(1, Ordering::SeqCst);
        MockConnection { id: self.id }
    }
}

impl Drop for MockConnection {
    fn drop(&mut self) {
        // Counts every destroyed instance, whether factory-made or cloned.
        DROPPED.fetch_add(1, Ordering::SeqCst);
    }
}
|
||
|
||
/// Factory that hands out `MockConnection`s with increasing ids, starting at 1.
pub struct MockConnectionFactory {
    /// Next id to assign; `create` advances it by one per connection.
    counter: AtomicU64,
}

impl MockConnectionFactory {
    /// Creates a factory whose first connection will receive id 1.
    pub fn new() -> Self {
        MockConnectionFactory {
            counter: AtomicU64::new(1),
        }
    }
}

impl Default for MockConnectionFactory {
    /// Equivalent to [`MockConnectionFactory::new`]. Provided so the type
    /// satisfies clippy's `new_without_default` and works with `Default`-based
    /// construction.
    fn default() -> Self {
        Self::new()
    }
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl PooledItemFactory<MockConnection> for MockConnectionFactory {
|
||
/// The trait on ConnectionPool expects:
|
||
/// async fn create(&self, timeout: Duration)
|
||
/// -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed>;
|
||
///
|
||
/// On success: Ok(Ok(MockConnection))
|
||
/// On a simulated “gRPC” failure: Ok(Err(Status::…))
|
||
/// On a transport/factory error: Err(Box<…>)
|
||
async fn create(
|
||
&self,
|
||
_timeout: Duration,
|
||
) -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed> {
|
||
// Simulate connection creation immediately succeeding.
|
||
CREATED.fetch_add(1, Ordering::SeqCst);
|
||
let next_id = self.counter.fetch_add(1, Ordering::Relaxed);
|
||
Ok(Ok(MockConnection { id: next_id }))
|
||
}
|
||
}
|
||
|
||
// --------------------------------------
|
||
// CLIENT WORKER
|
||
// --------------------------------------
|
||
//
|
||
// Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we:
|
||
// 1. Lock the shared Mutex<HashMap<u64, Arc<AtomicUsize>>> to fetch/insert an Arc<AtomicUsize> for this conn_id.
|
||
// 2. Lock the shared Mutex<HashSet<u64>> to record this conn_id as “seen.”
|
||
// 3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers.
|
||
// 4. Sleep 10–100 ms to simulate “work.”
|
||
// 5. Atomically decrement the counter.
|
||
// 6. Call `pooled.finish(Ok(()))` to return to the pool.
|
||
|
||
async fn client_worker(
|
||
pool: Arc<ConnectionPool<MockConnection>>,
|
||
usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>>,
|
||
seen_set: Arc<Mutex<HashSet<u64>>>,
|
||
max_consumers: usize,
|
||
worker_id: usize,
|
||
) {
|
||
for iteration in 0..10 {
|
||
match pool.clone().get_client().await {
|
||
Ok(pooled) => {
|
||
let conn: MockConnection = pooled.channel();
|
||
let conn_id = conn.id;
|
||
|
||
// 1. Fetch or insert the Arc<AtomicUsize> for this conn_id:
|
||
let counter_arc: Arc<AtomicUsize> = {
|
||
let mut guard = usage_map.lock().unwrap();
|
||
guard
|
||
.entry(conn_id)
|
||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)))
|
||
.clone()
|
||
// MutexGuard is dropped here
|
||
};
|
||
|
||
// 2. Record this conn_id in the shared HashSet of “seen” IDs:
|
||
{
|
||
let mut seen_guard = seen_set.lock().unwrap();
|
||
seen_guard.insert(conn_id);
|
||
// MutexGuard is dropped immediately
|
||
}
|
||
|
||
// 3. Atomically bump the count for this connection ID
|
||
let prev = counter_arc.fetch_add(1, Ordering::SeqCst);
|
||
let current = prev + 1;
|
||
assert!(
|
||
current <= max_consumers,
|
||
"Connection {} exceeded max_consumers (got {})",
|
||
conn_id,
|
||
current
|
||
);
|
||
|
||
println!(
|
||
"[worker {}][iter {}] got MockConnection id={} ({} concurrent)",
|
||
worker_id, iteration, conn_id, current
|
||
);
|
||
|
||
// 4. Simulate some work (10–100 ms)
|
||
let delay_ms = rand::thread_rng().gen_range(10..100);
|
||
sleep(Duration::from_millis(delay_ms)).await;
|
||
|
||
// 5. Decrement the usage counter
|
||
let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst);
|
||
let after = prev2 - 1;
|
||
println!(
|
||
"[worker {}][iter {}] returning MockConnection id={} (now {} remain)",
|
||
worker_id, iteration, conn_id, after
|
||
);
|
||
|
||
// 6. Return to the pool (mark success)
|
||
pooled.finish(Ok(())).await;
|
||
}
|
||
Err(status) => {
|
||
eprintln!(
|
||
"[worker {}][iter {}] failed to get client: {:?}",
|
||
worker_id, iteration, status
|
||
);
|
||
}
|
||
}
|
||
|
||
// Small random pause before next iteration to spread out load
|
||
let pause = rand::thread_rng().gen_range(0..20);
|
||
sleep(Duration::from_millis(pause)).await;
|
||
}
|
||
}
|
||
|
||
#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
|
||
async fn main() {
|
||
// --------------------------------------
|
||
// 1. Create factory and shared instrumentation
|
||
// --------------------------------------
|
||
let factory = Arc::new(MockConnectionFactory::new());
|
||
|
||
// Shared map: connection ID → Arc<AtomicUsize>
|
||
let usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>> =
|
||
Arc::new(Mutex::new(HashMap::new()));
|
||
|
||
// Shared set: record each unique connection ID we actually saw
|
||
let seen_set: Arc<Mutex<HashSet<u64>>> = Arc::new(Mutex::new(HashSet::new()));
|
||
|
||
// --------------------------------------
|
||
// 2. Pool parameters
|
||
// --------------------------------------
|
||
let connect_timeout = Duration::from_millis(500);
|
||
let connect_backoff = Duration::from_millis(100);
|
||
let max_consumers = 100; // test limit
|
||
let error_threshold = 2; // mock never fails
|
||
let max_idle_duration = Duration::from_secs(2);
|
||
let max_total_connections = 3;
|
||
let aggregate_metrics = None;
|
||
|
||
let pool: Arc<ConnectionPool<MockConnection>> = ConnectionPool::new(
|
||
factory,
|
||
connect_timeout,
|
||
connect_backoff,
|
||
max_consumers,
|
||
error_threshold,
|
||
max_idle_duration,
|
||
max_total_connections,
|
||
aggregate_metrics,
|
||
);
|
||
|
||
// --------------------------------------
|
||
// 3. Spawn worker tasks
|
||
// --------------------------------------
|
||
let num_workers = 10000;
|
||
let mut handles = Vec::with_capacity(num_workers);
|
||
let start_time = Instant::now();
|
||
|
||
for worker_id in 0..num_workers {
|
||
let pool_clone = Arc::clone(&pool);
|
||
let usage_clone = Arc::clone(&usage_map);
|
||
let seen_clone = Arc::clone(&seen_set);
|
||
let mc = max_consumers;
|
||
|
||
let handle = task::spawn(async move {
|
||
client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await;
|
||
});
|
||
handles.push(handle);
|
||
}
|
||
|
||
// --------------------------------------
|
||
// 4. Wait for workers to finish
|
||
// --------------------------------------
|
||
for handle in handles {
|
||
let _ = handle.await;
|
||
}
|
||
let elapsed = Instant::now().duration_since(start_time);
|
||
println!("All {} workers completed in {:?}", num_workers, elapsed);
|
||
|
||
// --------------------------------------
|
||
// 5. Print the total number of unique connections seen so far
|
||
// --------------------------------------
|
||
let unique_count = {
|
||
let seen_guard = seen_set.lock().unwrap();
|
||
seen_guard.len()
|
||
};
|
||
println!("Total unique connections used by workers: {}", unique_count);
|
||
|
||
// --------------------------------------
|
||
// 6. Sleep so the background sweeper can run (max_idle_duration = 2 s)
|
||
// --------------------------------------
|
||
sleep(Duration::from_secs(3)).await;
|
||
|
||
// --------------------------------------
|
||
// 7. Shutdown the pool
|
||
// --------------------------------------
|
||
let shutdown_pool = Arc::clone(&pool);
|
||
shutdown_pool.shutdown().await;
|
||
println!("Pool.shutdown() returned.");
|
||
|
||
// --------------------------------------
|
||
// 8. Verify that no background task still holds an Arc clone of `pool`.
|
||
// If any task is still alive (sweeper/create_connection), strong_count > 1.
|
||
// --------------------------------------
|
||
sleep(Duration::from_secs(1)).await; // give tasks time to exit
|
||
let sc = Arc::strong_count(&pool);
|
||
assert!(
|
||
sc == 1,
|
||
"Pool tasks did not all terminate: Arc::strong_count = {} (expected 1)",
|
||
sc
|
||
);
|
||
println!("Verified: all pool tasks have terminated (strong_count == 1).");
|
||
|
||
// --------------------------------------
|
||
// 9. Verify no MockConnection was leaked:
|
||
// CREATED must equal DROPPED.
|
||
// --------------------------------------
|
||
let created = CREATED.load(Ordering::SeqCst);
|
||
let dropped = DROPPED.load(Ordering::SeqCst);
|
||
assert!(
|
||
created == dropped,
|
||
"Leaked connections: created={} but dropped={}",
|
||
created,
|
||
dropped
|
||
);
|
||
println!(
|
||
"Verified: no connections leaked (created = {}, dropped = {}).",
|
||
created, dropped
|
||
);
|
||
|
||
// --------------------------------------
|
||
// 10. Because `client_worker` asserted inside that no connection
|
||
// ever exceeded `max_consumers`, reaching this point means that check passed.
|
||
// --------------------------------------
|
||
println!(
|
||
"All per-connection usage stayed within max_consumers = {}.",
|
||
max_consumers
|
||
);
|
||
|
||
println!("Load test complete; exiting cleanly.");
|
||
}
|