Mirror of https://github.com/neondatabase/neon.git (synced 2026-01-13 16:32:56 +00:00)
Add a connection pool module to the grpc client.
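For orientation, here is a minimal sketch of the call pattern the new pool expects from consumers. It is not part of the diff itself; pool, auth_interceptor, shard, and request are assumed to already be in scope, and the RPC shown is one the client already wraps.

    // Borrow a channel from the pool, run one RPC, then report the outcome so
    // the pool can track consecutive errors on this connection.
    let pooled = pool.get_client().await;
    let mut client = PageServiceClient::with_interceptor(
        pooled.channel(),
        auth_interceptor.for_shard(shard),
    );
    match client.rel_exists(tonic::Request::new(request)).await {
        Ok(response) => {
            pooled.finish(Ok(())).await;
            // ... use response ...
        }
        Err(status) => {
            pooled.finish(Err(status.clone())).await;
            // ... surface the error ...
        }
    }
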
@@ -10,6 +10,8 @@ http.workspace = true
thiserror.workspace = true
tonic.workspace = true
tracing.workspace = true
tokio = { version = "1.43.1", features = ["macros", "net", "io-util", "rt", "rt-multi-thread"] }
uuid = { version = "1", features = ["v4"] }

pageserver_page_api.workspace = true
utils.workspace = true

pageserver/client_grpc/src/client_cache.rs (new file, 262 lines)
@@ -0,0 +1,262 @@
use std::{collections::HashMap, sync::Arc, time::{Duration, Instant}};
use tokio::{sync::{Mutex, Notify, mpsc, watch}, time::sleep};
use tonic::transport::{Channel, Endpoint};

use tracing::info;
use uuid;

/// A pooled gRPC client with capacity tracking and error handling.
pub struct ConnectionPool {
    inner: Mutex<Inner>,

    // Config options that apply to each connection
    endpoint: String,
    max_consumers: usize,
    error_threshold: usize,
    connect_timeout: Duration,
    connect_backoff: Duration,

    // This notify is signaled when a connection is released or created.
    notify: Notify,

    // When it is time to create a new connection for the pool, we signal
    // a watch and a connection creation async wakes up and does the work.
    cc_watch_tx: watch::Sender<bool>,
    cc_watch_rx: watch::Receiver<bool>,

    // To acquire a connection from the pool, send a request
    // to this mpsc, and wait for a response.
    request_tx: mpsc::Sender<mpsc::Sender<PooledClient>>,
}

struct Inner {
    entries: HashMap<uuid::Uuid, ConnectionEntry>,

    // This is updated when a connection is dropped, or we fail
    // to create a new connection.
    last_connect_failure: Option<Instant>,
}

struct ConnectionEntry {
    channel: Channel,
    active_consumers: usize,
    consecutive_successes: usize,
    consecutive_errors: usize,
}

/// A client borrowed from the pool.
pub struct PooledClient {
    pub channel: Channel,
    pool: Arc<ConnectionPool>,
    id: uuid::Uuid,
}

impl ConnectionPool {
    /// Create a new pool and spawn the background task that handles requests.
    pub fn new(
        endpoint: &String,
        max_consumers: usize,
        error_threshold: usize,
        connect_timeout: Duration,
        connect_backoff: Duration,
    ) -> Arc<Self> {
        let (request_tx, mut request_rx) = mpsc::channel::<mpsc::Sender<PooledClient>>(100);
        let (watch_tx, watch_rx) = watch::channel(false);
        let pool = Arc::new(Self {
            inner: Mutex::new(Inner {
                entries: HashMap::new(),
                last_connect_failure: None,
            }),
            notify: Notify::new(),
            cc_watch_tx: watch_tx,
            cc_watch_rx: watch_rx,
            endpoint: endpoint.clone(),
            max_consumers: max_consumers,
            error_threshold,
            connect_timeout,
            connect_backoff,
            request_tx,
        });

        //
        // Background task to handle requests and create connections.
        //
        // TODO: These should be canceled when the ConnectionPool is dropped
        //

        let bg_cc_pool = Arc::clone(&pool);
        tokio::spawn(async move {
            loop {
                bg_cc_pool.create_connection().await;
            }
        });

        let bg_pool = Arc::clone(&pool);
        tokio::spawn(async move {
            while let Some(responder) = request_rx.recv().await {
                // TODO: This call should time out and return an error
                let (id, channel) = bg_pool.acquire_connection().await;
                let client = PooledClient { channel, pool: Arc::clone(&bg_pool), id };
                let _ = responder.send(client).await;
            }
        });

        pool
    }

    async fn acquire_connection(&self) -> (uuid::Uuid, Channel) {
        loop {
            // Reuse an existing healthy connection if available
            {
                let mut inner = self.inner.lock().await;
                // TODO: Use a heap, although the number of connections is small
                if let Some((&id, entry)) = inner.entries
                    .iter_mut()
                    .filter(|(_, e)| e.active_consumers < self.max_consumers)
                    .filter(|(_, e)| e.consecutive_errors < self.error_threshold)
                    .max_by_key(|(_, e)| e.active_consumers)
                {
                    entry.active_consumers += 1;
                    return (id, entry.channel.clone());
                }
                // There is no usable connection, so notify the connection creation async to make one. (It is
                // possible that a consumer will release a connection while the new one is being created, in
                // which case we will use it right away, but the new connection will be created anyway.)
                let _ = self.cc_watch_tx.send(true);

            }
            // Wait for a new connection, or for one of the consumers to release a connection
            // TODO: Put this notify in a timeout
            self.notify.notified().await;
        }
    }

    async fn create_connection(&self) -> () {

        // Wait to be signalled to create a connection.
        let mut recv = self.cc_watch_tx.subscribe();
        if !*self.cc_watch_rx.borrow() {
            while recv.changed().await.is_ok() {
                if *self.cc_watch_rx.borrow() {
                    break;
                }
            }
        }

        loop {
            //
            // TODO: This would be more accurate if it waited for a timer, and the timer
            // was reset when a connection failed. Using timestamps, we may miss new failures
            // that occur while we are sleeping.
            //
            // TODO: Should the backoff be exponential?
            //
            if let Some(delay) = {
                let inner = self.inner.lock().await;
                inner.last_connect_failure.and_then(|at| {
                    (at.elapsed() < self.connect_backoff)
                        .then(|| self.connect_backoff - at.elapsed())
                })
            } {
                sleep(delay).await;
            }
            //
            // Create a new connection.
            //
            // The connect timeout is also the timeout for an individual gRPC request
            // on this connection. (Requests made later on this channel will time out
            // with the same timeout.)
            //
            let attempt = tokio::time::timeout(
                self.connect_timeout,
                Endpoint::from_shared(self.endpoint.clone())
                    .expect("invalid endpoint")
                    .timeout(self.connect_timeout)
                    .connect(),
            ).await;

            match attempt {
                Ok(Ok(channel)) => {
                    {
                        let mut inner = self.inner.lock().await;
                        let id = uuid::Uuid::new_v4();
                        inner.entries.insert(id, ConnectionEntry {
                            channel: channel.clone(),
                            active_consumers: 0,
                            consecutive_successes: 0,
                            consecutive_errors: 0,
                        });
                        self.notify.notify_one();
                        let _ = self.cc_watch_tx.send(false);
                        return;
                    };
                }
                Ok(Err(_)) | Err(_) => {
                    let mut inner = self.inner.lock().await;
                    inner.last_connect_failure = Some(Instant::now());
                }
            }
        }
    }

    /// Get a client we can use to send gRPC messages.
    pub async fn get_client(&self) -> PooledClient {
        let (resp_tx, mut resp_rx) = mpsc::channel(1);
        self.request_tx.send(resp_tx).await.expect("ConnectionPool task has shut down");
        resp_rx.recv().await.expect("ConnectionPool task has shut down")
    }

    /// Return client to the pool, indicating success or error.
    pub async fn return_client(&self, id: uuid::Uuid, success: bool) {
        let mut inner = self.inner.lock().await;
        let mut new_failure = false;
        if let Some(entry) = inner.entries.get_mut(&id) {
            // TODO: This should be a debug_assert
            if entry.active_consumers <= 0 {
                panic!("A consumer completed when active_consumers was zero!")
            }
            entry.active_consumers = entry.active_consumers - 1;
            if entry.consecutive_errors < self.error_threshold {
                if success {
                    entry.consecutive_successes += 1;
                    entry.consecutive_errors = 0;
                } else {
                    entry.consecutive_errors += 1;
                    entry.consecutive_successes = 0;
                    if entry.consecutive_errors == self.error_threshold {
                        new_failure = true;
                    }
                }
            }
            //
            // Too many errors on this connection. If there are no active users,
            // remove it. Otherwise just wait for active_consumers to go to zero.
            // This connection will not be selected for new consumers.
            //
            if entry.consecutive_errors == self.error_threshold {
                let remove = entry.active_consumers;
                if new_failure {
                    inner.last_connect_failure = Some(Instant::now());
                    info!("Connection {} has failed", id);
                }
                if remove == 0 {
                    info!("Removing connection {} due to too many errors", id);
                    inner.entries.remove(&id);
                }
            } else {
                self.notify.notify_one();
            }
        }
    }
}

impl PooledClient {
    pub fn channel(&self) -> Channel {
        return self.channel.clone();
    }

    pub async fn finish(self, result: Result<(), tonic::Status>) {
        self.pool.return_client(self.id, result.is_ok()).await;
    }
}

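For reference, constructing a pool with the knobs defined above might look like the sketch below; the concrete values mirror the ones get_client passes later in this commit, and the endpoint URL is illustrative only.

    let pool = client_cache::ConnectionPool::new(
        &"http://pageserver:51051".to_string(), // gRPC endpoint (illustrative URL)
        5000,                        // max_consumers: concurrent borrowers allowed per connection
        5,                           // error_threshold: consecutive errors before a connection is retired
        Duration::from_millis(200),  // connect_timeout: also the per-request timeout on the channel
        Duration::from_secs(1),      // connect_backoff: minimum wait after a failed connect attempt
    );
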
@@ -5,12 +5,13 @@
//!
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::Duration;
use std::sync::Arc;

use bytes::Bytes;
use futures::Stream;
use thiserror::Error;
use tonic::metadata::AsciiMetadataValue;
use tonic::transport::Channel;

use pageserver_page_api::model::*;
use pageserver_page_api::proto;
@@ -18,9 +19,10 @@ use pageserver_page_api::proto;
use pageserver_page_api::proto::PageServiceClient;
use utils::shard::ShardIndex;

type MyPageServiceClient = pageserver_page_api::proto::PageServiceClient<
    tonic::service::interceptor::InterceptedService<tonic::transport::Channel, AuthInterceptor>,
>;
use tracing::info;


mod client_cache;

#[derive(Error, Debug)]
pub enum PageserverClientError {
@@ -43,7 +45,7 @@ pub struct PageserverClient {

    shard_map: HashMap<ShardIndex, String>,

    channels: RwLock<HashMap<ShardIndex, Channel>>,
    channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool>>>,

    auth_interceptor: AuthInterceptor,
}
@@ -73,11 +75,17 @@ impl PageserverClient {
        // Current sharding model assumes that all metadata is present only at shard 0.
        let shard = ShardIndex::unsharded();

        let mut client = self.get_client(shard).await?;
        let pooled_client = self.get_client(shard).await;
        let chan = pooled_client.channel();

        let mut client =
            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));

        let request = proto::RelExistsRequest::from(request);
        let response = client.rel_exists(tonic::Request::new(request)).await?;

        // TODO: check for an error and pass it to "finish"
        pooled_client.finish(Ok(())).await;
        Ok(response.get_ref().exists)
    }

@@ -88,11 +96,17 @@ impl PageserverClient {
        // Current sharding model assumes that all metadata is present only at shard 0.
        let shard = ShardIndex::unsharded();

        let mut client = self.get_client(shard).await?;
        let pooled_client = self.get_client(shard).await;
        let chan = pooled_client.channel();

        let mut client =
            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));

        let request = proto::RelSizeRequest::from(request);
        let response = client.rel_size(tonic::Request::new(request)).await?;

        // TODO: check for an error and pass it to "finish"
        pooled_client.finish(Ok(())).await;
        Ok(response.get_ref().num_blocks)
    }

@@ -100,23 +114,26 @@ impl PageserverClient {
        // FIXME: calculate the shard number correctly
        let shard = ShardIndex::unsharded();

        let mut client = self.get_client(shard).await?;
        let pooled_client = self.get_client(shard).await;
        let chan = pooled_client.channel();

        let mut client =
            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));

        let request = proto::GetPageRequest::from(request);
        let response = client.get_page(tonic::Request::new(request)).await?;
        let response: GetPageResponse = response.into_inner().try_into()?;
        if response.status != GetPageStatus::Ok {
            return Err(PageserverClientError::RequestError(tonic::Status::new(
                tonic::Code::Internal,
                format!(
                    "{:?} {}",
                    response.status,
                    response.reason.unwrap_or_default()
                ),
            )));
        let response = client.get_page(tonic::Request::new(request)).await;
        match response {
            Err(status) => {
                pooled_client.finish(Err(status.clone())).await;
                return Err(PageserverClientError::RequestError(status));
            }
            Ok(resp) => {
                pooled_client.finish(Ok(())).await;
                let response: GetPageResponse = resp.into_inner().try_into()?;
                return Ok(response.page_image);
            }
        }

        Ok(response.page_image)
    }

    // TODO: this should use model::GetPageRequest and GetPageResponse

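Several methods above and below still carry a "TODO: check for an error and pass it to finish". One hypothetical way to resolve them (not part of this commit) is a small helper that reports the RPC outcome to the pool before handing the result back:

    // Hypothetical helper: report the outcome of one RPC to the pool, then
    // return the result unchanged so the caller can keep using ? as before.
    async fn finish_with<T>(
        pooled_client: client_cache::PooledClient,
        result: Result<T, tonic::Status>,
    ) -> Result<T, tonic::Status> {
        match &result {
            Ok(_) => pooled_client.finish(Ok(())).await,
            Err(status) => pooled_client.finish(Err(status.clone())).await,
        }
        result
    }
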
@@ -127,12 +144,24 @@ impl PageserverClient {
        tonic::Response<tonic::codec::Streaming<proto::GetPageResponse>>,
        PageserverClientError,
    > {

        // Print a debug message
        // FIXME: calculate the shard number correctly
        let shard = ShardIndex::unsharded();

        let mut client = self.get_client(shard).await?;
        let pooled_client = self.get_client(shard).await;
        let chan = pooled_client.channel();

        let mut client =
            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));

        // Check for an error return from get_pages
        // Declare response

        // TODO: check for an error and pass it to "finish"
        pooled_client.finish(Ok(())).await;
        return Ok(client.get_pages(tonic::Request::new(requests)).await?);

        Ok(client.get_pages(tonic::Request::new(requests)).await?)
    }

    /// Process a request to get the size of a database.
@@ -142,12 +171,17 @@ impl PageserverClient {
    ) -> Result<u64, PageserverClientError> {
        // Current sharding model assumes that all metadata is present only at shard 0.
        let shard = ShardIndex::unsharded();
        let pooled_client = self.get_client(shard).await;
        let chan = pooled_client.channel();

        let mut client = self.get_client(shard).await?;
        let mut client =
            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));

        let request = proto::DbSizeRequest::from(request);
        let response = client.db_size(tonic::Request::new(request)).await?;

        // TODO: check for an error and pass it to "finish"
        pooled_client.finish(Ok(())).await;
        Ok(response.get_ref().num_bytes)
    }

@@ -163,7 +197,12 @@ impl PageserverClient {
        // Current sharding model assumes that all metadata is present only at shard 0.
        let shard = ShardIndex::unsharded();

        let mut client = self.get_client(shard).await?;
        let pooled_client = self.get_client(shard).await;
        let chan = pooled_client.channel();

        let mut client =
            PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));

        if gzip {
            client = client.accept_compressed(tonic::codec::CompressionEncoding::Gzip);
        }
@@ -171,47 +210,45 @@ impl PageserverClient {
        let request = proto::GetBaseBackupRequest::from(request);
        let response = client.get_base_backup(tonic::Request::new(request)).await?;

        // TODO: check for an error and pass it to "finish"
        pooled_client.finish(Ok(())).await;
        Ok(response)
    }

    /// Get a client for given shard
    ///
    /// This implements very basic caching. If we already have a client for the given shard,
    /// reuse it. If not, create a new client and put it to the cache.
    /// Get a client from the pool for this shard, also creating the pool if it doesn't exist.
    ///
    async fn get_client(
        &self,
        shard: ShardIndex,
    ) -> Result<MyPageServiceClient, PageserverClientError> {
        let reused_channel: Option<Channel> = {
            let channels = self.channels.read().unwrap();
    ) -> client_cache::PooledClient {

        let reused_pool: Option<Arc<client_cache::ConnectionPool>> = {
            let channels = self.channels.read().unwrap();
            channels.get(&shard).cloned()
        };

        let channel = if let Some(reused_channel) = reused_channel {
            reused_channel
        } else {
            let endpoint: tonic::transport::Endpoint = self
                .shard_map
                .get(&shard)
                .expect("no url for shard {shard}")
                .parse()?;
            let channel = endpoint.connect().await?;

            // Insert it to the cache so that it can be reused on subsequent calls. It's possible
            // that another thread did the same concurrently, in which case we will overwrite the
            // client in the cache.
            {
                let mut channels = self.channels.write().unwrap();
                channels.insert(shard, channel.clone());
        let usable_pool : Arc<client_cache::ConnectionPool>;
        match reused_pool {
            Some(pool) => {
                let pooled_client = pool.get_client().await;
                return pooled_client;
            }
            channel
        };
            None => {
                let new_pool = client_cache::ConnectionPool::new(
                    self.shard_map.get(&shard).unwrap(),
                    5000, 5, Duration::from_millis(200), Duration::from_secs(1));
                let mut write_pool = self.channels.write().unwrap();
                write_pool.insert(shard, new_pool.clone());
                usable_pool = new_pool.clone();
            }
        }

        let client =
            PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(shard));
        Ok(client)
        let pooled_client = usable_pool.get_client().await;
        return pooled_client;
    }

}

/// Inject tenant_id, timeline_id and authentication token to all pageserver requests.

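As the old comment about the channel cache noted, two tasks can still race in the new get_client: both miss the read-lock lookup, both build a ConnectionPool, and the second insert overwrites the first. A hypothetical variant (not in this commit) re-checks under the write lock so only one pool per shard is ever created:

    // Hypothetical: create or reuse the per-shard pool atomically under the write lock.
    let pool = {
        let mut pools = self.channels.write().unwrap();
        pools
            .entry(shard)
            .or_insert_with(|| {
                client_cache::ConnectionPool::new(
                    self.shard_map.get(&shard).unwrap(),
                    5000,
                    5,
                    Duration::from_millis(200),
                    Duration::from_secs(1),
                )
            })
            .clone()
    };
    pool.get_client().await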