diff --git a/pageserver/client_grpc/Cargo.toml b/pageserver/client_grpc/Cargo.toml
index f474006f29..4c619a0bf3 100644
--- a/pageserver/client_grpc/Cargo.toml
+++ b/pageserver/client_grpc/Cargo.toml
@@ -10,6 +10,8 @@
 http.workspace = true
 thiserror.workspace = true
 tonic.workspace = true
 tracing.workspace = true
+tokio = { version = "1.43.1", features = ["macros", "net", "io-util", "rt", "rt-multi-thread"] }
+uuid = { version = "1", features = ["v4"] }
 pageserver_page_api.workspace = true
 utils.workspace = true
diff --git a/pageserver/client_grpc/src/client_cache.rs b/pageserver/client_grpc/src/client_cache.rs
new file mode 100644
index 0000000000..41811aae63
--- /dev/null
+++ b/pageserver/client_grpc/src/client_cache.rs
@@ -0,0 +1,262 @@
+use std::{collections::HashMap, sync::Arc, time::{Duration, Instant}};
+use tokio::{sync::{Mutex, Notify, mpsc, watch}, time::sleep};
+use tonic::transport::{Channel, Endpoint};
+
+use tracing::info;
+use uuid;
+
+/// A pooled gRPC client with capacity tracking and error handling.
+pub struct ConnectionPool {
+    inner: Mutex<Inner>,
+
+    // Config options that apply to each connection
+    endpoint: String,
+    max_consumers: usize,
+    error_threshold: usize,
+    connect_timeout: Duration,
+    connect_backoff: Duration,
+
+    // This notify is signaled when a connection is released or created.
+    notify: Notify,
+
+    // When it is time to create a new connection for the pool, we signal
+    // a watch, and the connection-creation task wakes up and does the work.
+    cc_watch_tx: watch::Sender<bool>,
+    cc_watch_rx: watch::Receiver<bool>,
+
+    // To acquire a connection from the pool, send a request
+    // to this mpsc channel and wait for a response.
+    request_tx: mpsc::Sender<mpsc::Sender<PooledClient>>,
+}
+
+struct Inner {
+    entries: HashMap<uuid::Uuid, ConnectionEntry>,
+
+    // This is updated when a connection is dropped, or we fail
+    // to create a new connection.
+    last_connect_failure: Option<Instant>,
+}
+
+struct ConnectionEntry {
+    channel: Channel,
+    active_consumers: usize,
+    consecutive_successes: usize,
+    consecutive_errors: usize,
+}
+
+/// A client borrowed from the pool.
+pub struct PooledClient {
+    pub channel: Channel,
+    pool: Arc<ConnectionPool>,
+    id: uuid::Uuid,
+}
+
+impl ConnectionPool {
+    /// Create a new pool and spawn the background tasks that create connections
+    /// and hand out clients.
+    pub fn new(
+        endpoint: &String,
+        max_consumers: usize,
+        error_threshold: usize,
+        connect_timeout: Duration,
+        connect_backoff: Duration,
+    ) -> Arc<Self> {
+        let (request_tx, mut request_rx) = mpsc::channel::<mpsc::Sender<PooledClient>>(100);
+        let (watch_tx, watch_rx) = watch::channel(false);
+        let pool = Arc::new(Self {
+            inner: Mutex::new(Inner {
+                entries: HashMap::new(),
+                last_connect_failure: None,
+            }),
+            notify: Notify::new(),
+            cc_watch_tx: watch_tx,
+            cc_watch_rx: watch_rx,
+            endpoint: endpoint.clone(),
+            max_consumers,
+            error_threshold,
+            connect_timeout,
+            connect_backoff,
+            request_tx,
+        });
+
+        //
+        // Background tasks to handle requests and create connections.
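+        // One task creates connections when signalled on the cc_watch channel;
+        // the other serves acquire requests arriving on request_rx.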
+        //
+        // TODO: These should be canceled when the ConnectionPool is dropped
+        //
+
+        let bg_cc_pool = Arc::clone(&pool);
+        tokio::spawn(async move {
+            loop {
+                bg_cc_pool.create_connection().await;
+            }
+        });
+
+        let bg_pool = Arc::clone(&pool);
+        tokio::spawn(async move {
+            while let Some(responder) = request_rx.recv().await {
+                // TODO: This call should time out and return an error
+                let (id, channel) = bg_pool.acquire_connection().await;
+                let client = PooledClient { channel, pool: Arc::clone(&bg_pool), id };
+                let _ = responder.send(client).await;
+            }
+        });
+
+        pool
+    }
+
+    async fn acquire_connection(&self) -> (uuid::Uuid, Channel) {
+        loop {
+            // Reuse an existing healthy connection if available
+            {
+                let mut inner = self.inner.lock().await;
+                // TODO: Use a heap, although the number of connections is small
+                if let Some((&id, entry)) = inner.entries
+                    .iter_mut()
+                    .filter(|(_, e)| e.active_consumers < self.max_consumers)
+                    .filter(|(_, e)| e.consecutive_errors < self.error_threshold)
+                    .max_by_key(|(_, e)| e.active_consumers)
+                {
+                    entry.active_consumers += 1;
+                    return (id, entry.channel.clone());
+                }
+                // There is no usable connection, so signal the connection-creation task to make
+                // one. (It is possible that a consumer will release a connection while the new
+                // one is being created, in which case we will use it right away, but the new
+                // connection will be created anyway.)
+                let _ = self.cc_watch_tx.send(true);
+            }
+            // Wait for a new connection, or for one of the consumers to release a connection
+            // TODO: Put this notify in a timeout
+            self.notify.notified().await;
+        }
+    }
+
+    async fn create_connection(&self) {
+        // Wait to be signalled to create a connection.
+        let mut recv = self.cc_watch_tx.subscribe();
+        if !*self.cc_watch_rx.borrow() {
+            while recv.changed().await.is_ok() {
+                if *self.cc_watch_rx.borrow() {
+                    break;
+                }
+            }
+        }
+
+        loop {
+            //
+            // TODO: This would be more accurate if it waited for a timer, and the timer
+            // was reset when a connection failed. Using timestamps, we may miss new failures
+            // that occur while we are sleeping.
+            //
+            // TODO: Should the backoff be exponential?
+            //
+            if let Some(delay) = {
+                let inner = self.inner.lock().await;
+                inner.last_connect_failure.and_then(|at| {
+                    (at.elapsed() < self.connect_backoff)
+                        .then(|| self.connect_backoff - at.elapsed())
+                })
+            } {
+                sleep(delay).await;
+            }
+            //
+            // Create a new connection.
+            //
+            // The connect timeout is also the timeout for an individual gRPC request
+            // on this connection. (Requests made later on this channel will time out
+            // with the same timeout.)
+            //
+            let attempt = tokio::time::timeout(
+                self.connect_timeout,
+                Endpoint::from_shared(self.endpoint.clone())
+                    .expect("invalid endpoint")
+                    .timeout(self.connect_timeout)
+                    .connect(),
+            ).await;
+
+            match attempt {
+                Ok(Ok(channel)) => {
+                    let mut inner = self.inner.lock().await;
+                    let id = uuid::Uuid::new_v4();
+                    inner.entries.insert(id, ConnectionEntry {
+                        channel: channel.clone(),
+                        active_consumers: 0,
+                        consecutive_successes: 0,
+                        consecutive_errors: 0,
+                    });
+                    self.notify.notify_one();
+                    let _ = self.cc_watch_tx.send(false);
+                    return;
+                }
+                Ok(Err(_)) | Err(_) => {
+                    let mut inner = self.inner.lock().await;
+                    inner.last_connect_failure = Some(Instant::now());
+                }
+            }
+        }
+    }
+
+    /// Get a client we can use to send gRPC messages.
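+    ///
+    /// A rough usage sketch (hypothetical caller; the endpoint string and pool
+    /// parameters below are illustrative, not defaults):
+    ///
+    /// ```ignore
+    /// let pool = ConnectionPool::new(
+    ///     &"http://pageserver:51051".to_string(), // hypothetical endpoint
+    ///     100,                                    // max_consumers per connection
+    ///     5,                                      // error_threshold
+    ///     Duration::from_millis(200),             // connect_timeout
+    ///     Duration::from_secs(1),                 // connect_backoff
+    /// );
+    /// let client = pool.get_client().await;
+    /// let channel = client.channel();
+    /// // ... issue a gRPC request on `channel` ...
+    /// client.finish(Ok(())).await; // report the outcome back to the pool
+    /// ```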
+    pub async fn get_client(&self) -> PooledClient {
+        let (resp_tx, mut resp_rx) = mpsc::channel(1);
+        self.request_tx.send(resp_tx).await.expect("ConnectionPool task has shut down");
+        resp_rx.recv().await.expect("ConnectionPool task has shut down")
+    }
+
+    /// Return a client to the pool, indicating success or error.
+    pub async fn return_client(&self, id: uuid::Uuid, success: bool) {
+        let mut inner = self.inner.lock().await;
+        let mut new_failure = false;
+        if let Some(entry) = inner.entries.get_mut(&id) {
+            // TODO: This should be a debug_assert
+            if entry.active_consumers == 0 {
+                panic!("A consumer completed when active_consumers was zero!")
+            }
+            entry.active_consumers -= 1;
+            if entry.consecutive_errors < self.error_threshold {
+                if success {
+                    entry.consecutive_successes += 1;
+                    entry.consecutive_errors = 0;
+                } else {
+                    entry.consecutive_errors += 1;
+                    entry.consecutive_successes = 0;
+                    if entry.consecutive_errors == self.error_threshold {
+                        new_failure = true;
+                    }
+                }
+            }
+            //
+            // Too many errors on this connection. If there are no active users,
+            // remove it. Otherwise just wait for active_consumers to go to zero.
+            // This connection will not be selected for new consumers.
+            //
+            if entry.consecutive_errors == self.error_threshold {
+                let remove = entry.active_consumers;
+                if new_failure {
+                    inner.last_connect_failure = Some(Instant::now());
+                    info!("Connection {} has failed", id);
+                }
+                if remove == 0 {
+                    info!("Removing connection {} due to too many errors", id);
+                    inner.entries.remove(&id);
+                }
+            } else {
+                self.notify.notify_one();
+            }
+        }
+    }
+}
+
+impl PooledClient {
+    pub fn channel(&self) -> Channel {
+        self.channel.clone()
+    }
+
+    pub async fn finish(self, result: Result<(), tonic::Status>) {
+        self.pool.return_client(self.id, result.is_ok()).await;
+    }
+}
+
diff --git a/pageserver/client_grpc/src/lib.rs b/pageserver/client_grpc/src/lib.rs
index dc4cd09ada..0bcfda832d 100644
--- a/pageserver/client_grpc/src/lib.rs
+++ b/pageserver/client_grpc/src/lib.rs
@@ -5,12 +5,13 @@
 //!
 use std::collections::HashMap;
 use std::sync::RwLock;
+use std::time::Duration;
+use std::sync::Arc;
 
 use bytes::Bytes;
 use futures::Stream;
 use thiserror::Error;
 use tonic::metadata::AsciiMetadataValue;
-use tonic::transport::Channel;
 
 use pageserver_page_api::model::*;
 use pageserver_page_api::proto;
@@ -18,9 +19,10 @@ use pageserver_page_api::proto;
 use pageserver_page_api::proto::PageServiceClient;
 use utils::shard::ShardIndex;
 
-type MyPageServiceClient = pageserver_page_api::proto::PageServiceClient<
-    tonic::service::interceptor::InterceptedService<Channel, AuthInterceptor>,
->;
+use tracing::info;
+
+
+mod client_cache;
 
 #[derive(Error, Debug)]
 pub enum PageserverClientError {
@@ -43,7 +45,7 @@
 pub struct PageserverClient {
     shard_map: HashMap<ShardIndex, String>,
 
-    channels: RwLock<HashMap<ShardIndex, Channel>>,
+    channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool>>>,
 
     auth_interceptor: AuthInterceptor,
 }
@@ -73,11 +75,17 @@ impl PageserverClient {
         // Current sharding model assumes that all metadata is present only at shard 0.
         let shard = ShardIndex::unsharded();
 
-        let mut client = self.get_client(shard).await?;
+        let pooled_client = self.get_client(shard).await;
+        let chan = pooled_client.channel();
+
+        let mut client =
+            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
 
         let request = proto::RelExistsRequest::from(request);
         let response = client.rel_exists(tonic::Request::new(request)).await?;
 
+        // TODO: check for an error and pass it to "finish"
+        pooled_client.finish(Ok(())).await;
         Ok(response.get_ref().exists)
     }
 
@@ -88,11 +96,17 @@ impl PageserverClient {
         // Current sharding model assumes that all metadata is present only at shard 0.
         let shard = ShardIndex::unsharded();
 
-        let mut client = self.get_client(shard).await?;
+        let pooled_client = self.get_client(shard).await;
+        let chan = pooled_client.channel();
+
+        let mut client =
+            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
 
         let request = proto::RelSizeRequest::from(request);
         let response = client.rel_size(tonic::Request::new(request)).await?;
 
+        // TODO: check for an error and pass it to "finish"
+        pooled_client.finish(Ok(())).await;
         Ok(response.get_ref().num_blocks)
     }
 
@@ -100,23 +114,26 @@
         // FIXME: calculate the shard number correctly
         let shard = ShardIndex::unsharded();
 
-        let mut client = self.get_client(shard).await?;
+        let pooled_client = self.get_client(shard).await;
+        let chan = pooled_client.channel();
+
+        let mut client =
+            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
 
         let request = proto::GetPageRequest::from(request);
-        let response = client.get_page(tonic::Request::new(request)).await?;
-        let response: GetPageResponse = response.into_inner().try_into()?;
-        if response.status != GetPageStatus::Ok {
-            return Err(PageserverClientError::RequestError(tonic::Status::new(
-                tonic::Code::Internal,
-                format!(
-                    "{:?} {}",
-                    response.status,
-                    response.reason.unwrap_or_default()
-                ),
-            )));
+        let response = client.get_page(tonic::Request::new(request)).await;
+        match response {
+            Err(status) => {
+                pooled_client.finish(Err(status.clone())).await;
+                return Err(PageserverClientError::RequestError(status));
+            }
+            Ok(resp) => {
+                pooled_client.finish(Ok(())).await;
+                let response: GetPageResponse = resp.into_inner().try_into()?;
+                return Ok(response.page_image);
+            }
         }
-        Ok(response.page_image)
     }
 
     // TODO: this should use model::GetPageRequest and GetPageResponse
@@ -127,12 +144,24 @@
         tonic::Response<tonic::Streaming<proto::GetPageResponse>>,
         PageserverClientError,
     > {
         // FIXME: calculate the shard number correctly
         let shard = ShardIndex::unsharded();
 
-        let mut client = self.get_client(shard).await?;
+        let pooled_client = self.get_client(shard).await;
+        let chan = pooled_client.channel();
+
+        let mut client =
+            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
+
+        // TODO: check for an error return from get_pages and pass it to "finish"
+        pooled_client.finish(Ok(())).await;
+        return Ok(client.get_pages(tonic::Request::new(requests)).await?);
-        Ok(client.get_pages(tonic::Request::new(requests)).await?)
     }
 
     /// Process a request to get the size of a database.
@@ -142,12 +171,17 @@
     ) -> Result<u64, PageserverClientError> {
         // Current sharding model assumes that all metadata is present only at shard 0.
         let shard = ShardIndex::unsharded();
+        let pooled_client = self.get_client(shard).await;
+        let chan = pooled_client.channel();
 
-        let mut client = self.get_client(shard).await?;
+        let mut client =
+            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
 
         let request = proto::DbSizeRequest::from(request);
         let response = client.db_size(tonic::Request::new(request)).await?;
 
+        // TODO: check for an error and pass it to "finish"
+        pooled_client.finish(Ok(())).await;
         Ok(response.get_ref().num_bytes)
     }
 
@@ -163,7 +197,12 @@
         // Current sharding model assumes that all metadata is present only at shard 0.
         let shard = ShardIndex::unsharded();
 
-        let mut client = self.get_client(shard).await?;
+        let pooled_client = self.get_client(shard).await;
+        let chan = pooled_client.channel();
+
+        let mut client =
+            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
+
         if gzip {
             client = client.accept_compressed(tonic::codec::CompressionEncoding::Gzip);
         }
@@ -171,47 +210,45 @@
         let request = proto::GetBaseBackupRequest::from(request);
         let response = client.get_base_backup(tonic::Request::new(request)).await?;
 
+        // TODO: check for an error and pass it to "finish"
+        pooled_client.finish(Ok(())).await;
         Ok(response)
     }
 
     /// Get a client for given shard
     ///
-    /// This implements very basic caching. If we already have a client for the given shard,
-    /// reuse it. If not, create a new client and put it to the cache.
+    /// Get a client from the pool for this shard, also creating the pool if it doesn't exist.
+    ///
     async fn get_client(
         &self,
         shard: ShardIndex,
-    ) -> Result<MyPageServiceClient, PageserverClientError> {
-        let reused_channel: Option<Channel> = {
-            let channels = self.channels.read().unwrap();
+    ) -> client_cache::PooledClient {
+        let reused_pool: Option<Arc<client_cache::ConnectionPool>> = {
+            let channels = self.channels.read().unwrap();
             channels.get(&shard).cloned()
         };
 
-        let channel = if let Some(reused_channel) = reused_channel {
-            reused_channel
-        } else {
-            let endpoint: tonic::transport::Endpoint = self
-                .shard_map
-                .get(&shard)
-                .expect("no url for shard {shard}")
-                .parse()?;
-            let channel = endpoint.connect().await?;
-
-            // Insert it to the cache so that it can be reused on subsequent calls. It's possible
-            // that another thread did the same concurrently, in which case we will overwrite the
-            // client in the cache.
-            {
-                let mut channels = self.channels.write().unwrap();
-                channels.insert(shard, channel.clone());
+        let usable_pool: Arc<client_cache::ConnectionPool>;
+        match reused_pool {
+            Some(pool) => {
+                let pooled_client = pool.get_client().await;
+                return pooled_client;
             }
-            channel
-        };
+            None => {
+                let new_pool = client_cache::ConnectionPool::new(
+                    self.shard_map.get(&shard).unwrap(),
+                    5000, 5, Duration::from_millis(200), Duration::from_secs(1));
+                let mut write_pool = self.channels.write().unwrap();
+                write_pool.insert(shard, new_pool.clone());
+                usable_pool = new_pool.clone();
+            }
+        }
 
-        let client =
-            PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(shard));
-        Ok(client)
+        let pooled_client = usable_pool.get_client().await;
+        return pooled_client;
     }
+
 }
 
 /// Inject tenant_id, timeline_id and authentication token to all pageserver requests.