diff --git a/pageserver/client_grpc/examples/load_test.rs b/pageserver/client_grpc/examples/load_test.rs
deleted file mode 100644
index 0ac6f18c6e..0000000000
--- a/pageserver/client_grpc/examples/load_test.rs
+++ /dev/null
@@ -1,273 +0,0 @@
-// examples/load_test.rs, generated by AI
-
-use std::collections::{HashMap, HashSet};
-use std::sync::{
-    Arc, Mutex,
-    atomic::{AtomicU64, AtomicUsize, Ordering},
-};
-use std::time::{Duration, Instant};
-
-use rand::Rng;
-use tokio::task;
-use tokio::time::sleep;
-use tonic::Status;
-
-// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate.
-// Adjust these paths if necessary.
-use pageserver_client_grpc::client_cache::ConnectionPool;
-use pageserver_client_grpc::client_cache::PooledItemFactory;
-
-// --------------------------------------
-// GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections
-// --------------------------------------
-static CREATED: AtomicU64 = AtomicU64::new(0);
-static DROPPED: AtomicU64 = AtomicU64::new(0);
-
-// --------------------------------------
-// MockConnection + Factory
-// --------------------------------------
-
-#[derive(Debug)]
-pub struct MockConnection {
-    pub id: u64,
-}
-
-impl Clone for MockConnection {
-    fn clone(&self) -> Self {
-        // Count a clone as a “created” connection: every clone is eventually
-        // dropped (bumping DROPPED), so CREATED must be bumped here too for the
-        // final CREATED == DROPPED leak check to balance.
-        CREATED.fetch_add(1, Ordering::Relaxed);
-        MockConnection { id: self.id }
-    }
-}
-
-impl Drop for MockConnection {
-    fn drop(&mut self) {
-        // When a MockConnection actually gets dropped, bump the counter.
-        DROPPED.fetch_add(1, Ordering::SeqCst);
-    }
-}
-
-#[derive(Default)]
-pub struct MockConnectionFactory {
-    counter: AtomicU64,
-}
-
-#[async_trait::async_trait]
-impl PooledItemFactory<MockConnection> for MockConnectionFactory {
-    /// The trait on ConnectionPool expects:
-    ///     async fn create(&self, timeout: Duration)
-    ///         -> Result<Result<MockConnection, tonic::Status>, tokio::time::error::Elapsed>;
-    ///
-    /// On success: Ok(Ok(MockConnection))
-    /// On a simulated “gRPC” failure: Ok(Err(Status::…))
-    /// On a timeout: Err(Elapsed)
-    async fn create(
-        &self,
-        _timeout: Duration,
-    ) -> Result<Result<MockConnection, tonic::Status>, tokio::time::error::Elapsed> {
-        // Simulate connection creation immediately succeeding.
-        CREATED.fetch_add(1, Ordering::SeqCst);
-        let next_id = self.counter.fetch_add(1, Ordering::Relaxed);
-        Ok(Ok(MockConnection { id: next_id }))
-    }
-}
-
-// --------------------------------------
-// CLIENT WORKER
-// --------------------------------------
-//
-// Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we:
-//   1. Lock the shared Mutex<HashMap<u64, Arc<AtomicUsize>>> to fetch/insert an Arc<AtomicUsize> for this conn_id.
-//   2. Lock the shared Mutex<HashSet<u64>> to record this conn_id as “seen.”
-//   3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers.
-//   4. Sleep 10–100 ms to simulate “work.”
-//   5. Atomically decrement the counter.
-//   6. Call `pooled.finish(Ok(()))` to return to the pool.
-
-async fn client_worker(
-    pool: Arc<ConnectionPool<MockConnection>>,
-    usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>>,
-    seen_set: Arc<Mutex<HashSet<u64>>>,
-    max_consumers: usize,
-    worker_id: usize,
-) {
-    for iteration in 0..10 {
-        match pool.clone().get_client().await {
-            Ok(pooled) => {
-                let conn: MockConnection = pooled.channel();
-                let conn_id = conn.id;
-
-                // 1.
Fetch or insert the Arc for this conn_id: - let counter_arc: Arc = { - let mut guard = usage_map.lock().unwrap(); - guard - .entry(conn_id) - .or_insert_with(|| Arc::new(AtomicUsize::new(0))) - .clone() - // MutexGuard is dropped here - }; - - // 2. Record this conn_id in the shared HashSet of “seen” IDs: - { - let mut seen_guard = seen_set.lock().unwrap(); - seen_guard.insert(conn_id); - // MutexGuard is dropped immediately - } - - // 3. Atomically bump the count for this connection ID - let prev = counter_arc.fetch_add(1, Ordering::SeqCst); - let current = prev + 1; - assert!( - current <= max_consumers, - "Connection {conn_id} exceeded max_consumers (got {current})", - ); - - println!( - "[worker {worker_id}][iter {iteration}] got MockConnection id={conn_id} ({current} concurrent)", - ); - - // 4. Simulate some work (10–100 ms) - let delay_ms = rand::thread_rng().gen_range(10..100); - sleep(Duration::from_millis(delay_ms)).await; - - // 5. Decrement the usage counter - let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst); - let after = prev2 - 1; - println!( - "[worker {worker_id}][iter {iteration}] returning MockConnection id={conn_id} (now {after} remain)", - ); - - // 6. Return to the pool (mark success) - pooled.finish(Ok(())).await; - } - Err(status) => { - eprintln!( - "[worker {worker_id}][iter {iteration}] failed to get client: {status:?}", - ); - } - } - - // Small random pause before next iteration to spread out load - let pause = rand::thread_rng().gen_range(0..20); - sleep(Duration::from_millis(pause)).await; - } -} - -#[tokio::main(flavor = "multi_thread", worker_threads = 8)] -async fn main() { - // -------------------------------------- - // 1. Create factory and shared instrumentation - // -------------------------------------- - let factory = Arc::new(MockConnectionFactory::default()); - - // Shared map: connection ID → Arc - let usage_map: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - - // Shared set: record each unique connection ID we actually saw - let seen_set: Arc>> = Arc::new(Mutex::new(HashSet::new())); - - // -------------------------------------- - // 2. Pool parameters - // -------------------------------------- - let connect_timeout = Duration::from_millis(500); - let connect_backoff = Duration::from_millis(100); - let max_consumers = 100; // test limit - let error_threshold = 2; // mock never fails - let max_idle_duration = Duration::from_secs(2); - let max_total_connections = 3; - let aggregate_metrics = None; - - let pool: Arc> = ConnectionPool::new( - factory, - connect_timeout, - connect_backoff, - max_consumers, - error_threshold, - max_idle_duration, - max_total_connections, - aggregate_metrics, - ); - - // -------------------------------------- - // 3. Spawn worker tasks - // -------------------------------------- - let num_workers = 10000; - let mut handles = Vec::with_capacity(num_workers); - let start_time = Instant::now(); - - for worker_id in 0..num_workers { - let pool_clone = Arc::clone(&pool); - let usage_clone = Arc::clone(&usage_map); - let seen_clone = Arc::clone(&seen_set); - let mc = max_consumers; - - let handle = task::spawn(async move { - client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await; - }); - handles.push(handle); - } - - // -------------------------------------- - // 4. 
Wait for workers to finish
-    // --------------------------------------
-    for handle in handles {
-        let _ = handle.await;
-    }
-    let elapsed = Instant::now().duration_since(start_time);
-    println!("All {num_workers} workers completed in {elapsed:?}");
-
-    // --------------------------------------
-    // 5. Print the total number of unique connections seen so far
-    // --------------------------------------
-    let unique_count = {
-        let seen_guard = seen_set.lock().unwrap();
-        seen_guard.len()
-    };
-    println!("Total unique connections used by workers: {unique_count}");
-
-    // --------------------------------------
-    // 6. Sleep so the background sweeper can run (max_idle_duration = 2 s)
-    // --------------------------------------
-    sleep(Duration::from_secs(3)).await;
-
-    // --------------------------------------
-    // 7. Shutdown the pool
-    // --------------------------------------
-    let shutdown_pool = Arc::clone(&pool);
-    shutdown_pool.shutdown().await;
-    println!("Pool.shutdown() returned.");
-
-    // --------------------------------------
-    // 8. Verify that no background task still holds an Arc clone of `pool`.
-    //    If any task is still alive (sweeper/create_connection), strong_count > 1.
-    // --------------------------------------
-    sleep(Duration::from_secs(1)).await; // give tasks time to exit
-    let sc = Arc::strong_count(&pool);
-    assert!(
-        sc == 1,
-        "Pool tasks did not all terminate: Arc::strong_count = {sc} (expected 1)",
-    );
-    println!("Verified: all pool tasks have terminated (strong_count == 1).");
-
-    // --------------------------------------
-    // 9. Verify no MockConnection was leaked:
-    //    CREATED must equal DROPPED.
-    // --------------------------------------
-    let created = CREATED.load(Ordering::SeqCst);
-    let dropped = DROPPED.load(Ordering::SeqCst);
-    assert!(
-        created == dropped,
-        "Leaked connections: created={created} but dropped={dropped}",
-    );
-    println!("Verified: no connections leaked (created = {created}, dropped = {dropped}).");
-
-    // --------------------------------------
-    // 10. Because `client_worker` asserted inside that no connection
-    //     ever exceeded `max_consumers`, reaching this point means that check passed.
-    // --------------------------------------
-    println!("All per-connection usage stayed within max_consumers = {max_consumers}.");
-
-    println!("Load test complete; exiting cleanly.");
-}
diff --git a/pageserver/client_grpc/src/client_cache.rs b/pageserver/client_grpc/src/client_cache.rs
deleted file mode 100644
index 6da402d849..0000000000
--- a/pageserver/client_grpc/src/client_cache.rs
+++ /dev/null
@@ -1,705 +0,0 @@
-use std::{
-    collections::HashMap,
-    io::{self, Error, ErrorKind},
-    sync::Arc,
-    time::{Duration, Instant},
-};
-
-use priority_queue::PriorityQueue;
-
-use tokio::{
-    io::{AsyncRead, AsyncWrite, ReadBuf},
-    net::TcpStream,
-    sync::{Mutex, OwnedSemaphorePermit, Semaphore},
-    time::sleep,
-};
-use tonic::transport::{Channel, Endpoint};
-
-use uuid;
-
-use std::{
-    pin::Pin,
-    task::{Context, Poll},
-};
-
-use futures::future;
-use rand::{Rng, SeedableRng, rngs::StdRng};
-
-use bytes::BytesMut;
-use http::Uri;
-use hyper_util::rt::TokioIo;
-use tower::service_fn;
-
-use async_trait::async_trait;
-use tokio_util::sync::CancellationToken;
-
-//
-// "TokioTcp" is a flaky TCP stream wrapper for testing purposes, used to
-// simulate network errors and delays.
-//
-
-/// Wraps a `TcpStream`, buffers incoming data, and injects a random delay per fresh read/write.
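-/// Before each fresh read or write it checks a randomized deadline; until the
-/// deadline passes it returns `Pending` and spawns a sleeper task to wake the
-/// waker, then draws the next deadline uniformly from `0..=delay_ms`.
-///
-/// A minimal usage sketch (the address and the 50 ms cap are illustrative only):
-///
-/// ```ignore
-/// let tcp = TcpStream::connect("127.0.0.1:51315").await?;
-/// let flaky = TokioTcp::new(tcp, 50); // up to 50 ms of injected delay per I/O
-/// ```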
-pub struct TokioTcp { - tcp: TcpStream, - /// Maximum randomized delay in milliseconds - delay_ms: u64, - - /// Next deadline instant for delay - deadline: Instant, - /// Internal buffer of previously-read data - buffer: BytesMut, -} - -impl TokioTcp { - /// Create a new wrapper with given max delay (ms) - pub fn new(stream: TcpStream, delay_ms: u64) -> Self { - let initial = if delay_ms > 0 { - rand::thread_rng().gen_range(0..delay_ms) - } else { - 0 - }; - let deadline = Instant::now() + Duration::from_millis(initial); - TokioTcp { - tcp: stream, - delay_ms, - deadline, - buffer: BytesMut::new(), - } - } -} - -impl AsyncRead for TokioTcp { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - // Safe because TokioTcp is Unpin - let this = self.get_mut(); - - // 1) Drain any buffered data - if !this.buffer.is_empty() { - let to_copy = this.buffer.len().min(buf.remaining()); - buf.put_slice(&this.buffer.split_to(to_copy)); - return Poll::Ready(Ok(())); - } - - // 2) If we're still before the deadline, schedule a wake and return Pending - let now = Instant::now(); - if this.delay_ms > 0 && now < this.deadline { - let waker = cx.waker().clone(); - let wait = this.deadline - now; - tokio::spawn(async move { - sleep(wait).await; - waker.wake_by_ref(); - }); - return Poll::Pending; - } - - // 3) Past deadline: compute next random deadline - if this.delay_ms > 0 { - let next_ms = rand::thread_rng().gen_range(0..=this.delay_ms); - this.deadline = Instant::now() + Duration::from_millis(next_ms); - } - - // 4) Perform actual read into a temporary buffer - let mut tmp = [0u8; 4096]; - let mut rb = ReadBuf::new(&mut tmp); - match Pin::new(&mut this.tcp).poll_read(cx, &mut rb) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => { - let filled = rb.filled(); - if filled.is_empty() { - // EOF or zero bytes - Poll::Ready(Ok(())) - } else { - this.buffer.extend_from_slice(filled); - let to_copy = this.buffer.len().min(buf.remaining()); - buf.put_slice(&this.buffer.split_to(to_copy)); - Poll::Ready(Ok(())) - } - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - } - } -} - -impl AsyncWrite for TokioTcp { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - data: &[u8], - ) -> Poll> { - let this = self.get_mut(); - - // 1) If before deadline, schedule wake and return Pending - let now = Instant::now(); - if this.delay_ms > 0 && now < this.deadline { - let waker = cx.waker().clone(); - let wait = this.deadline - now; - tokio::spawn(async move { - sleep(wait).await; - waker.wake_by_ref(); - }); - return Poll::Pending; - } - - // 2) Past deadline: compute next random deadline - if this.delay_ms > 0 { - let next_ms = rand::thread_rng().gen_range(0..=this.delay_ms); - this.deadline = Instant::now() + Duration::from_millis(next_ms); - } - - // 3) Actual write - Pin::new(&mut this.tcp).poll_write(cx, data) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - Pin::new(&mut this.tcp).poll_flush(cx) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - Pin::new(&mut this.tcp).poll_shutdown(cx) - } -} - -#[async_trait] -pub trait PooledItemFactory: Send + Sync + 'static { - /// Create a new pooled item. 
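-    /// Returns a nested result: `Ok(Ok(item))` on success, `Ok(Err(Status))` for
-    /// a connection-level failure (the pool backs off and retries it), and
-    /// `Err(Elapsed)` when the attempt exceeded `connect_timeout`.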
- async fn create( - &self, - connect_timeout: Duration, - ) -> Result, tokio::time::error::Elapsed>; -} - -pub struct ChannelFactory { - endpoint: String, - max_delay_ms: u64, - drop_rate: f64, - hang_rate: f64, -} - -impl ChannelFactory { - pub fn new(endpoint: String, max_delay_ms: u64, drop_rate: f64, hang_rate: f64) -> Self { - ChannelFactory { - endpoint, - max_delay_ms, - drop_rate, - hang_rate, - } - } -} - -#[async_trait] -impl PooledItemFactory for ChannelFactory { - async fn create( - &self, - connect_timeout: Duration, - ) -> Result, tokio::time::error::Elapsed> { - let max_delay_ms = self.max_delay_ms; - let drop_rate = self.drop_rate; - let hang_rate = self.hang_rate; - - // This is a custom connector that inserts delays and errors, for - // testing purposes. It would normally be disabled by the config. - let connector = service_fn(move |uri: Uri| { - let drop_rate = drop_rate; - let hang_rate = hang_rate; - async move { - let mut rng = StdRng::from_entropy(); - // Simulate an indefinite hang - if hang_rate > 0.0 && rng.gen_bool(hang_rate) { - // never completes, to test timeout - return future::pending::, std::io::Error>>().await; - } - - // Random drop (connect error) - if drop_rate > 0.0 && rng.gen_bool(drop_rate) { - return Err(std::io::Error::other("simulated connect drop")); - } - - // Otherwise perform real TCP connect - let addr = match (uri.host(), uri.port()) { - // host + explicit port - (Some(host), Some(port)) => format!("{}:{}", host, port.as_str()), - // host only (no port) - (Some(host), None) => host.to_string(), - // neither? error out - _ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")), - }; - - let tcp = TcpStream::connect(addr).await?; - let tcpwrapper = TokioTcp::new(tcp, max_delay_ms); - Ok(TokioIo::new(tcpwrapper)) - } - }); - - let attempt = tokio::time::timeout( - connect_timeout, - Endpoint::from_shared(self.endpoint.clone()) - .expect("invalid endpoint") - .timeout(connect_timeout) - .connect_with_connector(connector), - ) - .await; - match attempt { - Ok(Ok(channel)) => { - // Connection succeeded - Ok(Ok(channel)) - } - Ok(Err(e)) => Ok(Err(tonic::Status::new( - tonic::Code::Unavailable, - format!("Failed to connect: {e}"), - ))), - Err(e) => Err(e), - } - } -} - -/// A pooled gRPC client with capacity tracking and error handling. -pub struct ConnectionPool { - inner: Mutex>, - - fact: Arc + Send + Sync>, - - connect_timeout: Duration, - connect_backoff: Duration, - /// The maximum number of consumers that can use a single connection. - max_consumers: usize, - /// The number of consecutive errors before a connection is removed from the pool. - error_threshold: usize, - /// The maximum duration a connection can be idle before being removed. - max_idle_duration: Duration, - max_total_connections: usize, - - channel_semaphore: Arc, - - shutdown_token: CancellationToken, - aggregate_metrics: Option>, -} - -struct Inner { - entries: HashMap>, - pq: PriorityQueue, - // This is updated when a connection is dropped, or we fail - // to create a new connection. - last_connect_failure: Option, - waiters: usize, - in_progress: usize, -} -struct ConnectionEntry { - channel: T, - active_consumers: usize, - consecutive_errors: usize, - last_used: Instant, -} - -/// A client borrowed from the pool. 
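-/// Typical lifecycle, as exercised by the load-test example (sketch):
-///
-/// ```ignore
-/// let pooled = pool.get_client().await?;
-/// let channel = pooled.channel();  // clone of the underlying channel/item
-/// // ... issue a request on `channel` ...
-/// pooled.finish(Ok(())).await;     // report the outcome and return it to the pool
-/// ```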
-pub struct PooledClient { - pub channel: T, - pool: Arc>, - is_ok: bool, - id: uuid::Uuid, - permit: OwnedSemaphorePermit, -} - -impl ConnectionPool { - #[allow(clippy::too_many_arguments)] - pub fn new( - fact: Arc + Send + Sync>, - connect_timeout: Duration, - connect_backoff: Duration, - max_consumers: usize, - error_threshold: usize, - max_idle_duration: Duration, - max_total_connections: usize, - aggregate_metrics: Option>, - ) -> Arc { - let shutdown_token = CancellationToken::new(); - let pool = Arc::new(Self { - inner: Mutex::new(Inner:: { - entries: HashMap::new(), - pq: PriorityQueue::new(), - last_connect_failure: None, - waiters: 0, - in_progress: 0, - }), - fact: Arc::clone(&fact), - connect_timeout, - connect_backoff, - max_consumers, - error_threshold, - max_idle_duration, - max_total_connections, - channel_semaphore: Arc::new(Semaphore::new(0)), - shutdown_token: shutdown_token.clone(), - aggregate_metrics: aggregate_metrics.clone(), - }); - - // Cancelable background task to sweep idle connections - let sweeper_token = shutdown_token.clone(); - let sweeper_pool = Arc::clone(&pool); - tokio::spawn(async move { - loop { - tokio::select! { - _ = sweeper_token.cancelled() => break, - _ = async { - sweeper_pool.sweep_idle_connections().await; - sleep(Duration::from_secs(5)).await; - } => {} - } - } - }); - - pool - } - - pub async fn shutdown(self: Arc) { - self.shutdown_token.cancel(); - - loop { - let all_idle = { - let inner = self.inner.lock().await; - inner.entries.values().all(|e| e.active_consumers == 0) - }; - if all_idle { - break; - } - sleep(Duration::from_millis(100)).await; - } - - // 4. Remove all entries - let mut inner = self.inner.lock().await; - inner.entries.clear(); - } - - /// Sweep and remove idle connections safely, burning their permits. - async fn sweep_idle_connections(self: &Arc) { - let mut ids_to_remove = Vec::new(); - let now = Instant::now(); - - // Remove idle entries. First collect permits for those connections so that - // no consumer will reserve them, then remove them from the pool. - { - let mut inner = self.inner.lock().await; - inner.entries.retain(|id, entry| { - if entry.active_consumers == 0 - && now.duration_since(entry.last_used) > self.max_idle_duration - { - // metric - if let Some(ref metrics) = self.aggregate_metrics { - metrics - .retry_counters - .with_label_values(&["connection_swept"]) - .inc(); - } - ids_to_remove.push(*id); - return false; // remove this entry - } - true - }); - // Remove the entries from the priority queue - for id in ids_to_remove { - inner.pq.remove(&id); - } - } - } - - // If we have a permit already, get a connection out of the heap - async fn get_conn_with_permit( - self: Arc, - permit: OwnedSemaphorePermit, - ) -> Option> { - let mut inner = self.inner.lock().await; - - // Pop the highest-active-consumers connection. There are no connections - // in the heap that have more than max_consumers active consumers. 
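-        // The priority is the entry's consumer count, so the pool packs new
-        // consumers onto the busiest connection below the cap; lightly used
-        // connections drain, go idle, and are eventually swept. An entry that
-        // reaches max_consumers stays out of the heap until return_client()
-        // re-inserts it.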
-        if let Some((id, _cons)) = inner.pq.pop() {
-            let entry = inner
-                .entries
-                .get_mut(&id)
-                .expect("pq and entries got out of sync");
-
-            let mut active_consumers = entry.active_consumers;
-            entry.active_consumers += 1;
-            entry.last_used = Instant::now();
-
-            let client = PooledClient::<T> {
-                channel: entry.channel.clone(),
-                pool: Arc::clone(&self),
-                is_ok: true,
-                id,
-                permit,
-            };
-
-            // re‐insert with updated priority
-            active_consumers += 1;
-            if active_consumers < self.max_consumers {
-                inner.pq.push(id, active_consumers as usize);
-            }
-            Some(client)
-        } else {
-            // If there is no connection to take, it is because permits for a connection
-            // need to drain. This can happen if a connection is removed because it has
-            // too many errors. It is taken out of the heap/hash table in this case, but
-            // we can't remove its permits until now.
-            //
-            // Just forget the permit and retry.
-            permit.forget();
-            None
-        }
-    }
-
-    pub async fn get_client(self: Arc<Self>) -> Result<PooledClient<T>, tonic::Status> {
-        // The pool is shutting down. Don't accept new connections.
-        if self.shutdown_token.is_cancelled() {
-            return Err(tonic::Status::unavailable("Pool is shutting down"));
-        }
-
-        // A loop is necessary because when a connection is draining, we have to return
-        // a permit and retry.
-        loop {
-            let self_clone = Arc::clone(&self);
-            let mut semaphore = Arc::clone(&self_clone.channel_semaphore);
-
-            match semaphore.try_acquire_owned() {
-                Ok(permit_) => {
-                    // We got a permit, so check the heap for a connection
-                    // we can use.
-                    let pool_conn = self_clone.get_conn_with_permit(permit_).await;
-                    match pool_conn {
-                        Some(pool_conn_) => {
-                            return Ok(pool_conn_);
-                        }
-                        None => {
-                            // No connection available. Forget the permit and retry.
-                            continue;
-                        }
-                    }
-                }
-                Err(_) => {
-                    if let Some(ref metrics) = self_clone.aggregate_metrics {
-                        metrics
-                            .retry_counters
-                            .with_label_values(&["sema_acquire_success"])
-                            .inc();
-                    }
-
-                    {
-                        //
-                        // This is going to generate enough connections to handle a burst,
-                        // but it may generate up to twice the number of connections needed
-                        // in the worst case. Extra connections will go idle and be cleaned
-                        // up.
-                        //
-                        let mut inner = self_clone.inner.lock().await;
-                        inner.waiters += 1;
-                        if inner.waiters > (inner.in_progress * self_clone.max_consumers)
-                            && (inner.entries.len() + inner.in_progress)
-                                < self_clone.max_total_connections
-                        {
-                            let self_clone_spawn = Arc::clone(&self_clone);
-                            tokio::task::spawn(async move {
-                                self_clone_spawn.create_connection().await;
-                            });
-                            inner.in_progress += 1;
-                        }
-                    }
-                    // Wait for a connection to become available, either because it
-                    // was created or because a connection was returned to the pool
-                    // by another consumer.
-                    semaphore = Arc::clone(&self_clone.channel_semaphore);
-                    let conn_permit = semaphore.acquire_owned().await.unwrap();
-                    {
-                        let mut inner = self_clone.inner.lock().await;
-                        inner.waiters -= 1;
-                    }
-                    // We got a permit, check the heap for a connection.
-                    let pool_conn = self_clone.get_conn_with_permit(conn_permit).await;
-                    match pool_conn {
-                        Some(pool_conn_) => {
-                            return Ok(pool_conn_);
-                        }
-                        None => {
-                            // No connection was found, forget the permit and retry.
-                            continue;
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    async fn create_connection(&self) {
-        // Generate a random backoff to add some jitter so that connections
-        // don't all retry at the same time.
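-        // The delay starts at a random fraction of `connect_backoff` and, on
-        // each failure below, grows by a random amount up to its current value
-        // (roughly doubling in the worst case), capped at one minute.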
- let mut backoff_delay = Duration::from_millis( - rand::thread_rng().gen_range(0..=self.connect_backoff.as_millis() as u64), - ); - - loop { - if self.shutdown_token.is_cancelled() { - return; - } - - // Back off. - // Loop because failure can occur while we are sleeping, so wait - // until the failure stopped for at least one backoff period. Backoff - // period includes some jitter, so that if multiple connections are - // failing, they don't all retry at the same time. - while let Some(delay) = { - let inner = self.inner.lock().await; - inner.last_connect_failure.and_then(|at| { - (at.elapsed() < backoff_delay).then(|| backoff_delay - at.elapsed()) - }) - } { - sleep(delay).await; - } - - // - // Create a new connection. - // - // The connect timeout is also the timeout for an individual gRPC request - // on this connection. (Requests made later on this channel will time out - // with the same timeout.) - // - if let Some(ref metrics) = self.aggregate_metrics { - metrics - .retry_counters - .with_label_values(&["connection_attempt"]) - .inc(); - } - - let attempt = self.fact.create(self.connect_timeout).await; - - match attempt { - // Connection succeeded - Ok(Ok(channel)) => { - { - if let Some(ref metrics) = self.aggregate_metrics { - metrics - .retry_counters - .with_label_values(&["connection_success"]) - .inc(); - } - let mut inner = self.inner.lock().await; - let id = uuid::Uuid::new_v4(); - inner.entries.insert( - id, - ConnectionEntry:: { - channel: channel.clone(), - active_consumers: 0, - consecutive_errors: 0, - last_used: Instant::now(), - }, - ); - inner.pq.push(id, 0); - inner.in_progress -= 1; - self.channel_semaphore.add_permits(self.max_consumers); - return; - }; - } - // Connection failed, back off and retry - Ok(Err(_)) | Err(_) => { - if let Some(ref metrics) = self.aggregate_metrics { - metrics - .retry_counters - .with_label_values(&["connect_failed"]) - .inc(); - } - let mut inner = self.inner.lock().await; - inner.last_connect_failure = Some(Instant::now()); - // Add some jitter so that every connection doesn't retry at once - let jitter = rand::thread_rng().gen_range(0..=backoff_delay.as_millis() as u64); - backoff_delay = - Duration::from_millis(backoff_delay.as_millis() as u64 + jitter); - - // Do not backoff longer than one minute - if backoff_delay > Duration::from_secs(60) { - backoff_delay = Duration::from_secs(60); - } - // continue the loop to retry - } - } - } - } - - /// Return client to the pool, indicating success or error. - pub async fn return_client(&self, id: uuid::Uuid, success: bool, permit: OwnedSemaphorePermit) { - let mut inner = self.inner.lock().await; - if let Some(entry) = inner.entries.get_mut(&id) { - entry.last_used = Instant::now(); - if entry.active_consumers == 0 { - panic!("A consumer completed when active_consumers was zero!") - } - entry.active_consumers -= 1; - if success { - if entry.consecutive_errors < self.error_threshold { - entry.consecutive_errors = 0; - } - } else { - entry.consecutive_errors += 1; - if entry.consecutive_errors == self.error_threshold { - if let Some(ref metrics) = self.aggregate_metrics { - metrics - .retry_counters - .with_label_values(&["connection_dropped"]) - .inc(); - } - } - } - - // - // Too many errors on this connection. If there are no active users, - // remove it. Otherwise just wait for active_consumers to go to zero. - // This connection will not be selected for new consumers. 
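-            // Permits for a removed connection drain lazily: the returning consumer
-            // forgets its permit below, and any waiter that later pops a stale permit
-            // in get_conn_with_permit() also forgets it and retries.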
- // - let active_consumers = entry.active_consumers; - if entry.consecutive_errors >= self.error_threshold { - // too many errors, remove the connection permanently. Once it drains, - // it will be dropped. - if inner.pq.get_priority(&id).is_some() { - inner.pq.remove(&id); - } - - // remove from entries - // check if entry is in inner - if inner.entries.contains_key(&id) { - inner.entries.remove(&id); - } - inner.last_connect_failure = Some(Instant::now()); - - // The connection has been removed, it's permits will be - // drained because if we look for a connection and it's not there - // we just forget the permit. However, this process can be a little - // bit faster if we just forget permits as the connections are returned. - permit.forget(); - } else { - // update its priority in the queue - if inner.pq.get_priority(&id).is_some() { - inner.pq.change_priority(&id, active_consumers); - } else { - // This connection is not in the heap, but it has space - // for more consumers. Put it back in the heap. - if active_consumers < self.max_consumers { - inner.pq.push(id, active_consumers); - } - } - } - } - } -} - -impl PooledClient { - pub fn channel(&self) -> T { - self.channel.clone() - } - pub async fn finish(mut self, result: Result<(), tonic::Status>) { - self.is_ok = result.is_ok(); - self.pool - .return_client(self.id, self.is_ok, self.permit) - .await; - } -} diff --git a/pageserver/client_grpc/src/lib.rs b/pageserver/client_grpc/src/lib.rs index ee773ec378..d3353b9aad 100644 --- a/pageserver/client_grpc/src/lib.rs +++ b/pageserver/client_grpc/src/lib.rs @@ -1,451 +1,4 @@ -//! Pageserver Data API client -//! -//! - Manage connections to pageserver -//! - Send requests to correct shards -//! -use std::collections::HashMap; -use std::fmt::Debug; -use std::sync::Arc; -use std::sync::RwLock; -use std::time::Duration; +mod client; +mod pool; -use bytes::Bytes; -use futures::{Stream, StreamExt}; -use thiserror::Error; -use tonic::metadata::AsciiMetadataValue; -use tonic::transport::Channel; - -use pageserver_page_api::proto; -use pageserver_page_api::proto::PageServiceClient; -use pageserver_page_api::*; -use utils::shard::ShardIndex; - -pub mod client; -pub mod client_cache; -pub mod pool; -pub mod request_tracker; - -use metrics::{IntCounterVec, core::Collector}; - -#[derive(Error, Debug)] -pub enum PageserverClientError { - #[error("could not connect to service: {0}")] - ConnectError(#[from] tonic::transport::Error), - #[error("could not perform request: {0}`")] - RequestError(#[from] tonic::Status), - #[error("protocol error: {0}")] - ProtocolError(#[from] ProtocolError), - - #[error("could not perform request: {0}`")] - InvalidUri(#[from] http::uri::InvalidUri), - - #[error("could not perform request: {0}`")] - Other(String), -} - -#[derive(Clone, Debug)] -pub struct PageserverClientAggregateMetrics { - pub request_counters: IntCounterVec, - pub retry_counters: IntCounterVec, -} - -impl Default for PageserverClientAggregateMetrics { - fn default() -> Self { - Self::new() - } -} - -impl PageserverClientAggregateMetrics { - pub fn new() -> Self { - let request_counters = IntCounterVec::new( - metrics::core::Opts::new( - "backend_requests_total", - "Number of requests from backends.", - ), - &["request_kind"], - ) - .unwrap(); - - let retry_counters = IntCounterVec::new( - metrics::core::Opts::new( - "backend_requests_retries_total", - "Number of retried requests from backends.", - ), - &["request_kind"], - ) - .unwrap(); - Self { - request_counters, - retry_counters, - } - } - - pub fn 
collect(&self) -> Vec { - let mut metrics = Vec::new(); - metrics.append(&mut self.request_counters.collect()); - metrics.append(&mut self.retry_counters.collect()); - metrics - } -} - -pub struct PageserverClient { - _tenant_id: String, - _timeline_id: String, - - _auth_token: Option, - - shard_map: HashMap, - - channels: RwLock>>>, - - auth_interceptor: AuthInterceptor, - - client_cache_options: ClientCacheOptions, - - aggregate_metrics: Option>, -} -#[derive(Clone)] -pub struct ClientCacheOptions { - pub max_consumers: usize, - pub error_threshold: usize, - pub connect_timeout: Duration, - pub connect_backoff: Duration, - pub max_idle_duration: Duration, - pub max_total_connections: usize, - pub max_delay_ms: u64, - pub drop_rate: f64, - pub hang_rate: f64, -} - -impl PageserverClient { - /// TODO: this doesn't currently react to changes in the shard map. - pub fn new( - tenant_id: &str, - timeline_id: &str, - auth_token: &Option, - shard_map: HashMap, - ) -> Self { - let options = ClientCacheOptions { - max_consumers: 5000, - error_threshold: 5, - connect_timeout: Duration::from_secs(5), - connect_backoff: Duration::from_secs(1), - max_idle_duration: Duration::from_secs(60), - max_total_connections: 100000, - max_delay_ms: 0, - drop_rate: 0.0, - hang_rate: 0.0, - }; - Self::new_with_config(tenant_id, timeline_id, auth_token, shard_map, options, None) - } - pub fn new_with_config( - tenant_id: &str, - timeline_id: &str, - auth_token: &Option, - shard_map: HashMap, - options: ClientCacheOptions, - metrics: Option>, - ) -> Self { - Self { - _tenant_id: tenant_id.to_string(), - _timeline_id: timeline_id.to_string(), - _auth_token: auth_token.clone(), - shard_map, - channels: RwLock::new(HashMap::new()), - auth_interceptor: AuthInterceptor::new(tenant_id, timeline_id, auth_token.as_deref()), - client_cache_options: options, - aggregate_metrics: metrics, - } - } - pub async fn process_check_rel_exists_request( - &self, - request: CheckRelExistsRequest, - ) -> Result { - // Current sharding model assumes that all metadata is present only at shard 0. - let shard = ShardIndex::unsharded(); - let pooled_client = self.get_client(shard).await; - let chan = pooled_client.channel(); - - let mut client = - PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - - let request = proto::CheckRelExistsRequest::from(request); - let response = client.check_rel_exists(tonic::Request::new(request)).await; - - match response { - Err(status) => { - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - Err(PageserverClientError::RequestError(status)) - } - Ok(resp) => { - pooled_client.finish(Ok(())).await; // Pass success to finish - Ok(resp.get_ref().exists) - } - } - } - - pub async fn process_get_rel_size_request( - &self, - request: GetRelSizeRequest, - ) -> Result { - // Current sharding model assumes that all metadata is present only at shard 0. 
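-        // Every branch below must call `pooled_client.finish(..)` exactly once so
-        // that the connection's consumer slot and semaphore permit are released
-        // back to the pool.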
- let shard = ShardIndex::unsharded(); - let pooled_client = self.get_client(shard).await; - let chan = pooled_client.channel(); - - let mut client = - PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - - let request = proto::GetRelSizeRequest::from(request); - let response = client.get_rel_size(tonic::Request::new(request)).await; - - match response { - Err(status) => { - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - Err(PageserverClientError::RequestError(status)) - } - Ok(resp) => { - pooled_client.finish(Ok(())).await; // Pass success to finish - Ok(resp.get_ref().num_blocks) - } - } - } - - // Request a single batch of pages - // - // TODO: This opens a new gRPC stream for every request, which is extremely inefficient - pub async fn get_page( - &self, - request: GetPageRequest, - ) -> Result, PageserverClientError> { - // FIXME: calculate the shard number correctly - let shard = ShardIndex::unsharded(); - let pooled_client = self.get_client(shard).await; - let chan = pooled_client.channel(); - - let mut client = - PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - - let request = proto::GetPageRequest::from(request); - - let request_stream = futures::stream::once(std::future::ready(request)); - - let mut response_stream = client - .get_pages(tonic::Request::new(request_stream)) - .await? - .into_inner(); - - let Some(response) = response_stream.next().await else { - return Err(PageserverClientError::Other( - "no response received for getpage request".to_string(), - )); - }; - - if let Some(ref metrics) = self.aggregate_metrics { - metrics - .request_counters - .with_label_values(&["get_page"]) - .inc(); - } - - match response { - Err(status) => { - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - Err(PageserverClientError::RequestError(status)) - } - Ok(resp) => { - pooled_client.finish(Ok(())).await; // Pass success to finish - let response: GetPageResponse = resp.into(); - Ok(response.page_images.to_vec()) - } - } - } - - // Open a stream for requesting pages - // - // TODO: This is a pretty low level interface, the caller should not need to be concerned - // with streams. But 'get_page' is currently very naive and inefficient. - pub async fn get_pages( - &self, - requests: impl Stream + Send + 'static, - ) -> std::result::Result< - tonic::Response>, - PageserverClientError, - > { - // FIXME: calculate the shard number correctly - let shard = ShardIndex::unsharded(); - let pooled_client = self.get_client(shard).await; - let chan = pooled_client.channel(); - - let mut client = - PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - - let response = client.get_pages(tonic::Request::new(requests)).await; - - match response { - Err(status) => { - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - Err(PageserverClientError::RequestError(status)) - } - Ok(resp) => Ok(resp), - } - } - - /// Process a request to get the size of a database. - pub async fn process_get_dbsize_request( - &self, - request: GetDbSizeRequest, - ) -> Result { - // Current sharding model assumes that all metadata is present only at shard 0. 
- let shard = ShardIndex::unsharded(); - let pooled_client = self.get_client(shard).await; - let chan = pooled_client.channel(); - - let mut client = - PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - - let request = proto::GetDbSizeRequest::from(request); - let response = client.get_db_size(tonic::Request::new(request)).await; - - match response { - Err(status) => { - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - Err(PageserverClientError::RequestError(status)) - } - Ok(resp) => { - pooled_client.finish(Ok(())).await; // Pass success to finish - Ok(resp.get_ref().num_bytes) - } - } - } - /// Process a request to get the size of a database. - pub async fn get_base_backup( - &self, - request: GetBaseBackupRequest, - gzip: bool, - ) -> std::result::Result< - tonic::Response>, - PageserverClientError, - > { - // Current sharding model assumes that all metadata is present only at shard 0. - let shard = ShardIndex::unsharded(); - let pooled_client = self.get_client(shard).await; - let chan = pooled_client.channel(); - - let mut client = - PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - - if gzip { - client = client.accept_compressed(tonic::codec::CompressionEncoding::Gzip); - } - - let request = proto::GetBaseBackupRequest::from(request); - let response = client.get_base_backup(tonic::Request::new(request)).await; - - match response { - Err(status) => { - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - Err(PageserverClientError::RequestError(status)) - } - Ok(resp) => { - pooled_client.finish(Ok(())).await; // Pass success to finish - Ok(resp) - } - } - } - /// Get a client for given shard - /// - /// Get a client from the pool for this shard, also creating the pool if it doesn't exist. - /// - async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient { - let reused_pool: Option>> = { - let channels = self.channels.read().unwrap(); - channels.get(&shard).cloned() - }; - - let usable_pool = match reused_pool { - Some(pool) => { - let pooled_client = pool.get_client().await.unwrap(); - return pooled_client; - } - None => { - // Create a new pool using client_cache_options - // declare new_pool - - let channel_fact = Arc::new(client_cache::ChannelFactory::new( - self.shard_map.get(&shard).unwrap().clone(), - self.client_cache_options.max_delay_ms, - self.client_cache_options.drop_rate, - self.client_cache_options.hang_rate, - )); - let new_pool = client_cache::ConnectionPool::new( - channel_fact, - self.client_cache_options.connect_timeout, - self.client_cache_options.connect_backoff, - self.client_cache_options.max_consumers, - self.client_cache_options.error_threshold, - self.client_cache_options.max_idle_duration, - self.client_cache_options.max_total_connections, - self.aggregate_metrics.clone(), - ); - let mut write_pool = self.channels.write().unwrap(); - write_pool.insert(shard, new_pool.clone()); - new_pool.clone() - } - }; - - usable_pool.get_client().await.unwrap() - } -} - -/// Inject tenant_id, timeline_id and authentication token to all pageserver requests. 
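-/// Concretely, `call()` below attaches `neon-tenant-id`, `neon-timeline-id`,
-/// and, when set, `neon-shard-id` and `authorization: Bearer <token>` metadata
-/// to every outgoing request.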
-#[derive(Clone)] -pub struct AuthInterceptor { - tenant_id: AsciiMetadataValue, - shard_id: Option, - timeline_id: AsciiMetadataValue, - - auth_header: Option, // including "Bearer " prefix -} - -impl AuthInterceptor { - pub fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self { - Self { - tenant_id: tenant_id.parse().expect("could not parse tenant id"), - shard_id: None, - timeline_id: timeline_id.parse().expect("could not parse timeline id"), - auth_header: auth_token - .map(|t| format!("Bearer {t}")) - .map(|t| t.parse().expect("could not parse auth token")), - } - } - - fn for_shard(&self, shard_id: ShardIndex) -> Self { - let mut with_shard = self.clone(); - with_shard.shard_id = Some( - shard_id - .to_string() - .parse() - .expect("could not parse shard id"), - ); - with_shard - } -} - -impl tonic::service::Interceptor for AuthInterceptor { - fn call(&mut self, mut req: tonic::Request<()>) -> Result, tonic::Status> { - req.metadata_mut() - .insert("neon-tenant-id", self.tenant_id.clone()); - if let Some(shard_id) = &self.shard_id { - req.metadata_mut().insert("neon-shard-id", shard_id.clone()); - } - req.metadata_mut() - .insert("neon-timeline-id", self.timeline_id.clone()); - if let Some(auth_header) = &self.auth_header { - req.metadata_mut() - .insert("authorization", auth_header.clone()); - } - - Ok(req) - } -} +pub use client::PageserverClient; diff --git a/pageserver/client_grpc/src/pool.rs b/pageserver/client_grpc/src/pool.rs index 1fe5c6958a..cbcf26656e 100644 --- a/pageserver/client_grpc/src/pool.rs +++ b/pageserver/client_grpc/src/pool.rs @@ -39,7 +39,7 @@ use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot}; use tonic::transport::{Channel, Endpoint}; use tracing::warn; -use pageserver_page_api::{self as page_api, GetPageRequest, GetPageResponse}; +use pageserver_page_api as page_api; use utils::id::{TenantId, TimelineId}; use utils::shard::ShardIndex; @@ -358,9 +358,9 @@ pub struct StreamPool { } type StreamID = usize; -type RequestSender = Sender<(GetPageRequest, ResponseSender)>; -type RequestReceiver = Receiver<(GetPageRequest, ResponseSender)>; -type ResponseSender = oneshot::Sender>; +type RequestSender = Sender<(page_api::GetPageRequest, ResponseSender)>; +type RequestReceiver = Receiver<(page_api::GetPageRequest, ResponseSender)>; +type ResponseSender = oneshot::Sender>; struct StreamEntry { /// Sends caller requests to the stream task. The stream task exits when this is dropped. @@ -400,7 +400,10 @@ impl StreamPool { /// * Allow spinning up multiple streams concurrently, but don't overshoot limits. /// /// For now, we just do something simple and functional, but very inefficient (linear scan). - pub async fn send(&self, req: GetPageRequest) -> tonic::Result { + pub async fn send( + &self, + req: page_api::GetPageRequest, + ) -> tonic::Result { // Acquire a permit. For simplicity, we drop it when this method returns. This may exceed // the queue depth if a caller goes away while a request is in flight, but that's okay. We // do the same for queue depth tracking. diff --git a/pageserver/client_grpc/src/request_tracker.rs b/pageserver/client_grpc/src/request_tracker.rs deleted file mode 100644 index 52e17843ff..0000000000 --- a/pageserver/client_grpc/src/request_tracker.rs +++ /dev/null @@ -1,578 +0,0 @@ -//! The request tracker dispatches GetPage- and other requests to pageservers, managing a pool of -//! connections and gRPC streams. -//! -//! 
There is usually one global instance of ShardedRequestTracker in an application, in particular -//! in the neon extension's communicator process. The application calls the async functions in -//! ShardedRequestTracker, which routes them to the correct pageservers, taking sharding into -//! account. In the future, there can be multiple pageservers per shard, and RequestTracker manages -//! load balancing between them, but that's not implemented yet. - -use crate::AuthInterceptor; -use crate::ClientCacheOptions; -use crate::PageserverClientAggregateMetrics; -use crate::client_cache; -use crate::client_cache::ChannelFactory; -use crate::client_cache::ConnectionPool; -use pageserver_page_api::GetPageRequest; -use pageserver_page_api::GetPageResponse; -use pageserver_page_api::proto; -use pageserver_page_api::*; -use std::sync::Arc; -use std::sync::atomic::AtomicU64; -use tonic::{Request, transport::Channel}; - -use utils::shard::ShardIndex; - -use pageserver_page_api::proto::PageServiceClient; -use tokio_stream::wrappers::ReceiverStream; - -use tonic::{Code, Status}; - -use async_trait::async_trait; -use std::time::Duration; - -use client_cache::PooledItemFactory; - -/// StreamReturner represents a gRPC stream to a pageserver. -/// -/// To send a request: -/// 1. insert the request's ID, along with a channel to receive the response -/// 2. send the request to 'sender' -#[derive(Clone)] -pub struct StreamReturner { - sender: tokio::sync::mpsc::Sender, - #[allow(clippy::type_complexity)] - sender_hashmap: Arc< - tokio::sync::Mutex< - Option< - std::collections::HashMap< - u64, - tokio::sync::mpsc::Sender>, - >, - >, - >, - >, -} - -pub struct StreamFactory { - connection_pool: Arc>, - auth_interceptor: AuthInterceptor, - shard: ShardIndex, -} - -impl StreamFactory { - pub fn new( - connection_pool: Arc>, - auth_interceptor: AuthInterceptor, - shard: ShardIndex, - ) -> Self { - StreamFactory { - connection_pool, - auth_interceptor, - shard, - } - } -} - -#[async_trait] -impl PooledItemFactory for StreamFactory { - async fn create( - &self, - _connect_timeout: Duration, - ) -> Result, tokio::time::error::Elapsed> { - let pool_clone: Arc> = Arc::clone(&self.connection_pool); - let pooled_client = pool_clone.get_client().await; - let channel = pooled_client.unwrap().channel(); - let mut client = PageServiceClient::with_interceptor( - channel, - self.auth_interceptor.for_shard(self.shard), - ); - - let (sender, receiver) = tokio::sync::mpsc::channel::(1000); - let outbound = ReceiverStream::new(receiver); - - let client_resp = client.get_pages(Request::new(outbound)).await; - - match client_resp { - Err(status) => { - // TODO: Convert this error correctly - Ok(Err(tonic::Status::new( - status.code(), - format!("Failed to connect to pageserver: {}", status.message()), - ))) - } - Ok(resp) => { - let stream_returner = StreamReturner { - sender: sender.clone(), - sender_hashmap: Arc::new(tokio::sync::Mutex::new(Some( - std::collections::HashMap::new(), - ))), - }; - let map = Arc::clone(&stream_returner.sender_hashmap); - - tokio::spawn(async move { - let map_clone = Arc::clone(&map); - let mut inner = resp.into_inner(); - loop { - match inner.message().await { - Err(e) => { - tracing::info!("error received on getpage stream: {e}"); - break; // Exit the loop if no more messages - } - Ok(None) => { - break; // Sender closed the stream - } - Ok(Some(response)) => { - // look up stream in hash map - let mut hashmap = map_clone.lock().await; - let hashmap = - hashmap.as_mut().expect("no other task clears 
the hashmap"); - if let Some(sender) = hashmap.get(&response.request_id) { - // Send the response to the original request sender - if let Err(e) = sender.send(Ok(response.clone())).await { - eprintln!("Failed to send response: {e}"); - } - hashmap.remove(&response.request_id); - } else { - eprintln!( - "No sender found for request ID: {}", - response.request_id - ); - } - } - } - } - // Don't accept any more requests - - // Close every sender stream in the hashmap - let mut hashmap_opt = map_clone.lock().await; - let hashmap = hashmap_opt - .as_mut() - .expect("no other task clears the hashmap"); - for sender in hashmap.values() { - let error = Status::new(Code::Unknown, "Stream closed"); - if let Err(e) = sender.send(Err(error)).await { - eprintln!("Failed to send close response: {e}"); - } - } - *hashmap_opt = None; - }); - - Ok(Ok(stream_returner)) - } - } - } -} - -#[derive(Clone)] -pub struct RequestTracker { - _cur_id: Arc, - stream_pool: Arc>, - unary_pool: Arc>, - auth_interceptor: AuthInterceptor, - shard: ShardIndex, -} - -impl RequestTracker { - pub fn new( - stream_pool: Arc>, - unary_pool: Arc>, - auth_interceptor: AuthInterceptor, - shard: ShardIndex, - ) -> Self { - let cur_id = Arc::new(AtomicU64::new(0)); - - RequestTracker { - _cur_id: cur_id.clone(), - stream_pool, - unary_pool, - auth_interceptor, - shard, - } - } - - pub async fn send_process_check_rel_exists_request( - &self, - req: CheckRelExistsRequest, - ) -> Result { - loop { - let unary_pool = Arc::clone(&self.unary_pool); - let pooled_client = unary_pool.get_client().await.unwrap(); - let channel = pooled_client.channel(); - let mut ps_client = PageServiceClient::with_interceptor( - channel, - self.auth_interceptor.for_shard(self.shard), - ); - let request = proto::CheckRelExistsRequest::from(req); - let response = ps_client - .check_rel_exists(tonic::Request::new(request)) - .await; - - match response { - Err(status) => { - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - continue; - } - Ok(resp) => { - pooled_client.finish(Ok(())).await; // Pass success to finish - return Ok(resp.get_ref().exists); - } - } - } - } - - pub async fn send_process_get_rel_size_request( - &self, - req: GetRelSizeRequest, - ) -> Result { - loop { - // Current sharding model assumes that all metadata is present only at shard 0. - let unary_pool = Arc::clone(&self.unary_pool); - let pooled_client = unary_pool.get_client().await.unwrap(); - let channel = pooled_client.channel(); - let mut ps_client = PageServiceClient::with_interceptor( - channel, - self.auth_interceptor.for_shard(self.shard), - ); - - let request = proto::GetRelSizeRequest::from(req); - let response = ps_client.get_rel_size(tonic::Request::new(request)).await; - - match response { - Err(status) => { - tracing::info!("send_process_get_rel_size_request: got error {status}, retrying"); - pooled_client.finish(Err(status.clone())).await; // Pass error to finish - continue; - } - Ok(resp) => { - pooled_client.finish(Ok(())).await; // Pass success to finish - return Ok(resp.get_ref().num_blocks); - } - } - } - } - - pub async fn send_process_get_dbsize_request( - &self, - req: GetDbSizeRequest, - ) -> Result { - loop { - // Current sharding model assumes that all metadata is present only at shard 0. 
-            let unary_pool = Arc::clone(&self.unary_pool);
-            let pooled_client = unary_pool.get_client().await.unwrap();
-            let channel = pooled_client.channel();
-            let mut ps_client = PageServiceClient::with_interceptor(
-                channel,
-                self.auth_interceptor.for_shard(self.shard),
-            );
-
-            let request = proto::GetDbSizeRequest::from(req);
-            let response = ps_client.get_db_size(tonic::Request::new(request)).await;
-
-            match response {
-                Err(status) => {
-                    pooled_client.finish(Err(status.clone())).await; // Pass error to finish
-                    continue;
-                }
-                Ok(resp) => {
-                    pooled_client.finish(Ok(())).await; // Pass success to finish
-                    return Ok(resp.get_ref().num_bytes);
-                }
-            }
-        }
-    }
-
-    pub async fn send_getpage_request(
-        &mut self,
-        req: GetPageRequest,
-    ) -> Result<GetPageResponse, tonic::Status> {
-        loop {
-            let request = req.clone();
-            // Increment cur_id
-            //let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1;
-            let request_id = request.request_id;
-            let (response_sender, mut response_receiver) =
-                tokio::sync::mpsc::channel::<Result<proto::GetPageResponse, Status>>(1);
-            //request.request_id = request_id;
-
-            // Get a stream from the stream pool
-            let pool_clone = Arc::clone(&self.stream_pool);
-            let sender_stream_pool = pool_clone.get_client().await;
-            let stream_returner = match sender_stream_pool {
-                Ok(stream_ret) => stream_ret,
-                Err(_e) => {
-                    // retry
-                    continue;
-                }
-            };
-            let returner = stream_returner.channel();
-            let map = returner.sender_hashmap.clone();
-            // Insert the response sender into the hashmap
-            {
-                if let Some(map_inner) = map.lock().await.as_mut() {
-                    let old = map_inner.insert(request_id, response_sender);
-
-                    // request IDs must be unique
-                    if old.is_some() {
-                        panic!("request with ID {request_id} is already in-flight");
-                    }
-                } else {
-                    // The stream was closed. Try a different one.
-                    tracing::info!("stream was concurrently closed");
-                    continue;
-                }
-            }
-            let sent = returner
-                .sender
-                .send(proto::GetPageRequest::from(request))
-                .await;
-
-            if let Err(_e) = sent {
-                // Remove the request from the map if sending failed
-                {
-                    if let Some(map_inner) = map.lock().await.as_mut() {
-                        // remove from hashmap
-                        map_inner.remove(&request_id);
-                    }
-                }
-                stream_returner
-                    .finish(Err(Status::new(Code::Unknown, "Failed to send request")))
-                    .await;
-                continue;
-            }
-
-            let response = response_receiver.recv().await;
-            match response {
-                Some(resp) => {
-                    match resp {
-                        Err(_status) => {
-                            // Handle the case where the response was not received
-                            stream_returner
-                                .finish(Err(Status::new(
-                                    Code::Unknown,
-                                    "Failed to receive response",
-                                )))
-                                .await;
-                            continue;
-                        }
-                        Ok(resp) => {
-                            stream_returner.finish(Ok(())).await;
-                            return Ok(resp.clone().into());
-                        }
-                    }
-                }
-                None => {
-                    // Handle the case where the response channel was closed
-                    stream_returner
-                        .finish(Err(Status::new(Code::Unknown, "Response channel closed")))
-                        .await;
-                    continue;
-                }
-            }
-        }
-    }
-}
-
-struct ShardedRequestTrackerInner {
-    // Hashmap of shard index to RequestTracker
-    trackers: std::collections::HashMap<ShardIndex, RequestTracker>,
-}
-pub struct ShardedRequestTracker {
-    inner: Arc<std::sync::Mutex<ShardedRequestTrackerInner>>,
-    tcp_client_cache_options: ClientCacheOptions,
-    stream_client_cache_options: ClientCacheOptions,
-}
-
-//
-// TODO: Functions in the ShardedRequestTracker should be able to time out and
-// cancel a request. The request should return an error if it is cancelled.
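-// A hedged sketch of one way to do that (not implemented here) is to wrap each
-// call site in tokio::time::timeout, where `deadline` is a caller-supplied
-// Duration:
-//
-//     match tokio::time::timeout(deadline, tracker.send_getpage_request(req)).await {
-//         Ok(result) => result,
-//         Err(_elapsed) => Err(Status::deadline_exceeded("request timed out")),
-//     }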
-// - -impl Default for ShardedRequestTracker { - fn default() -> Self { - ShardedRequestTracker::new() - } -} - -impl ShardedRequestTracker { - pub fn new() -> Self { - // - // Default configuration for the client. These could be added to a config file - // - let tcp_client_cache_options = ClientCacheOptions { - max_delay_ms: 0, - drop_rate: 0.0, - hang_rate: 0.0, - connect_timeout: Duration::from_secs(1), - connect_backoff: Duration::from_millis(100), - max_consumers: 8, // Streams per connection - error_threshold: 10, - max_idle_duration: Duration::from_secs(5), - max_total_connections: 8, - }; - let stream_client_cache_options = ClientCacheOptions { - max_delay_ms: 0, - drop_rate: 0.0, - hang_rate: 0.0, - connect_timeout: Duration::from_secs(1), - connect_backoff: Duration::from_millis(100), - max_consumers: 64, // Requests per stream - error_threshold: 10, - max_idle_duration: Duration::from_secs(5), - max_total_connections: 64, // Total allowable number of streams - }; - ShardedRequestTracker { - inner: Arc::new(std::sync::Mutex::new(ShardedRequestTrackerInner { - trackers: std::collections::HashMap::new(), - })), - tcp_client_cache_options, - stream_client_cache_options, - } - } - - pub async fn update_shard_map( - &self, - shard_urls: std::collections::HashMap, - metrics: Option>, - tenant_id: String, - timeline_id: String, - auth_str: Option<&str>, - ) { - let mut trackers = std::collections::HashMap::new(); - for (shard, endpoint_url) in shard_urls { - // - // Create a pool of streams for streaming get_page requests - // - let channel_fact: Arc + Send + Sync> = - Arc::new(ChannelFactory::new( - endpoint_url.clone(), - self.tcp_client_cache_options.max_delay_ms, - self.tcp_client_cache_options.drop_rate, - self.tcp_client_cache_options.hang_rate, - )); - let new_pool = ConnectionPool::new( - Arc::clone(&channel_fact), - self.tcp_client_cache_options.connect_timeout, - self.tcp_client_cache_options.connect_backoff, - self.tcp_client_cache_options.max_consumers, - self.tcp_client_cache_options.error_threshold, - self.tcp_client_cache_options.max_idle_duration, - self.tcp_client_cache_options.max_total_connections, - metrics.clone(), - ); - - let auth_interceptor = - AuthInterceptor::new(tenant_id.as_str(), timeline_id.as_str(), auth_str); - - let stream_pool = ConnectionPool::::new( - Arc::new(StreamFactory::new( - new_pool.clone(), - auth_interceptor.clone(), - ShardIndex::unsharded(), - )), - self.stream_client_cache_options.connect_timeout, - self.stream_client_cache_options.connect_backoff, - self.stream_client_cache_options.max_consumers, - self.stream_client_cache_options.error_threshold, - self.stream_client_cache_options.max_idle_duration, - self.stream_client_cache_options.max_total_connections, - metrics.clone(), - ); - - // - // Create a client pool for unary requests - // - - let unary_pool = ConnectionPool::new( - Arc::clone(&channel_fact), - self.tcp_client_cache_options.connect_timeout, - self.tcp_client_cache_options.connect_backoff, - self.tcp_client_cache_options.max_consumers, - self.tcp_client_cache_options.error_threshold, - self.tcp_client_cache_options.max_idle_duration, - self.tcp_client_cache_options.max_total_connections, - metrics.clone(), - ); - // - // Create a new RequestTracker for this shard - // - let new_tracker = RequestTracker::new(stream_pool, unary_pool, auth_interceptor, shard); - trackers.insert(shard, new_tracker); - } - let mut inner = self.inner.lock().unwrap(); - inner.trackers = trackers; - } - - pub async fn get_page(&self, req: 
GetPageRequest) -> Result { - // Get shard index from the request and look up the RequestTracker instance for that shard - let shard_index = ShardIndex::unsharded(); // TODO! - let mut tracker = self.lookup_tracker_for_shard(shard_index)?; - - let response = tracker.send_getpage_request(req).await; - match response { - Ok(resp) => Ok(resp), - Err(e) => Err(tonic::Status::unknown(format!("Failed to get page: {e}"))), - } - } - - pub async fn process_get_dbsize_request( - &self, - request: GetDbSizeRequest, - ) -> Result { - // Current sharding model assumes that all metadata is present only at shard 0. - let tracker = self.lookup_tracker_for_shard(ShardIndex::unsharded())?; - - let response = tracker.send_process_get_dbsize_request(request).await; - match response { - Ok(resp) => Ok(resp), - Err(e) => Err(e), - } - } - - pub async fn process_get_rel_size_request( - &self, - request: GetRelSizeRequest, - ) -> Result { - // Current sharding model assumes that all metadata is present only at shard 0. - let tracker = self.lookup_tracker_for_shard(ShardIndex::unsharded())?; - - let response = tracker.send_process_get_rel_size_request(request).await; - match response { - Ok(resp) => Ok(resp), - Err(e) => Err(e), - } - } - - pub async fn process_check_rel_exists_request( - &self, - request: CheckRelExistsRequest, - ) -> Result { - // Current sharding model assumes that all metadata is present only at shard 0. - let tracker = self.lookup_tracker_for_shard(ShardIndex::unsharded())?; - - let response = tracker.send_process_check_rel_exists_request(request).await; - match response { - Ok(resp) => Ok(resp), - Err(e) => Err(e), - } - } - - #[allow(clippy::result_large_err)] - fn lookup_tracker_for_shard( - &self, - shard_index: ShardIndex, - ) -> Result { - let inner = self.inner.lock().unwrap(); - if let Some(t) = inner.trackers.get(&shard_index) { - Ok(t.clone()) - } else { - Err(tonic::Status::not_found(format!( - "Shard {shard_index} not found", - ))) - } - } -} diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index e5d6e28f46..01c6bea2e5 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -26,17 +26,6 @@ use utils::id::TenantTimelineId; use utils::lsn::Lsn; use utils::shard::ShardIndex; -use axum::Router; -use axum::body::Body; -use axum::extract::State; -use axum::response::Response; - -use http::StatusCode; -use http::header::CONTENT_TYPE; - -use metrics::proto::MetricFamily; -use metrics::{Encoder, TextEncoder}; - use crate::util::tokio_thread_local_stats::AllThreadLocalStats; use crate::util::{request_stats, tokio_thread_local_stats}; @@ -185,62 +174,12 @@ pub(crate) fn main(args: Args) -> anyhow::Result<()> { main_impl(args, thread_local_stats) }) } -async fn get_metrics( - State(state): State>, -) -> Response { - let metrics = state.collect(); - - info!("metrics: {metrics:?}"); - // When we call TextEncoder::encode() below, it will immediately return an - // error if a metric family has no metrics, so we need to preemptively - // filter out metric families with no metrics. 
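-    // TextEncoder::encode() fails on a metric family with no samples, which
-    // would otherwise turn the whole /metrics response into a 500.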
- let metrics = metrics - .into_iter() - .filter(|m| !m.get_metric().is_empty()) - .collect::>(); - - let encoder = TextEncoder::new(); - let mut buffer = vec![]; - - if let Err(e) = encoder.encode(&metrics, &mut buffer) { - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header(CONTENT_TYPE, "application/text") - .body(Body::from(e.to_string())) - .unwrap() - } else { - Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, encoder.format_type()) - .body(Body::from(buffer)) - .unwrap() - } -} - async fn main_impl( args: Args, all_thread_local_stats: AllThreadLocalStats, ) -> anyhow::Result<()> { let args: &'static Args = Box::leak(Box::new(args)); - // Vector of pageserver clients - let client_metrics = Arc::new(pageserver_client_grpc::PageserverClientAggregateMetrics::new()); - - use axum::routing::get; - let app = Router::new() - .route("/metrics", get(get_metrics)) - .with_state(client_metrics.clone()); - - // TODO: make configurable. Or listen on unix domain socket? - let listener = tokio::net::TcpListener::bind("127.0.0.1:9090") - .await - .unwrap(); - - tokio::spawn(async { - tracing::info!("metrics listener spawned"); - axum::serve(listener, app).await.unwrap() - }); - let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new( reqwest::Client::new(), // TODO: support ssl_ca_file for https APIs in pagebench. args.mgmt_api_endpoint.clone(), diff --git a/pgxn/neon/communicator/src/worker_process/main_loop.rs b/pgxn/neon/communicator/src/worker_process/main_loop.rs index 17dad6a560..24be5f4987 100644 --- a/pgxn/neon/communicator/src/worker_process/main_loop.rs +++ b/pgxn/neon/communicator/src/worker_process/main_loop.rs @@ -13,7 +13,7 @@ use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess}; use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest}; use crate::neon_request::{NeonIORequest, NeonIOResult}; use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable}; -use pageserver_client_grpc::client::PageserverClient; +use pageserver_client_grpc::PageserverClient; use pageserver_page_api as page_api; use metrics::{IntCounter, IntCounterVec};