Use a semaphore to gate access to connections. Add metrics for testing.

This commit is contained in:
Elizabeth Murray
2025-05-23 11:46:06 -07:00
parent bb28109ffa
commit af9379ccf6
5 changed files with 333 additions and 123 deletions
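The core of the change: connection checkout is now gated by a tokio::sync::Semaphore that starts with zero permits. Every connection added to the pool contributes max_consumers permits, a caller must hold an OwnedSemaphorePermit before it is handed a channel, and capacity flows back automatically when the permit is dropped (or is burned with forget() when a connection is retired). The sketch below illustrates only that gating pattern; the SimplePool and Checkout types are hypothetical stand-ins, not the ConnectionPool in this commit.

use std::sync::Arc;
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};

/// Hypothetical pool used only to illustrate the permit-gating pattern.
struct SimplePool {
    gate: Arc<Semaphore>,         // starts with zero permits
    max_consumers: usize,         // permits added per connection
    channels: Mutex<Vec<String>>, // stand-in for tonic `Channel`s
}

/// A checkout keeps its permit; dropping it returns capacity to the gate.
struct Checkout {
    channel: String,
    _permit: OwnedSemaphorePermit,
}

impl SimplePool {
    fn new(max_consumers: usize) -> Arc<Self> {
        Arc::new(Self {
            gate: Arc::new(Semaphore::new(0)),
            max_consumers,
            channels: Mutex::new(Vec::new()),
        })
    }

    /// Publishing a connection makes `max_consumers` checkouts possible.
    async fn add_connection(&self, channel: String) {
        self.channels.lock().await.push(channel);
        self.gate.add_permits(self.max_consumers);
    }

    /// Waits until some connection has spare capacity, then hands one out.
    async fn checkout(&self) -> Checkout {
        let permit = Arc::clone(&self.gate).acquire_owned().await.unwrap();
        let channel = self.channels.lock().await.first().cloned().unwrap();
        Checkout { channel, _permit: permit }
    }
}

#[tokio::main]
async fn main() {
    let pool = SimplePool::new(2);
    pool.add_connection("grpc://pageserver:51051".to_string()).await;
    let a = pool.checkout().await;
    let _b = pool.checkout().await;
    // A third checkout would wait here until `a` or `_b` is dropped.
    drop(a);
    let _c = pool.checkout().await;
}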

@@ -17,6 +17,7 @@ rand = "0.8"
tokio-util = { version = "0.7", features = ["compat"] }
hyper-util = "0.1.9"
hyper = "1.6.0"
metrics.workspace = true
pageserver_page_api.workspace = true

@@ -1,14 +1,37 @@
use std::{
collections::HashMap,
sync::Arc,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::{Duration, Instant},
};
use tokio::{
sync::{Mutex, Notify, mpsc, watch},
sync::{Mutex, mpsc, watch, Semaphore, OwnedSemaphorePermit},
time::sleep,
net::TcpStream,
io::{AsyncRead, AsyncWrite, ReadBuf},
};
use tonic::transport::{Channel, Endpoint};
use uuid;
use std::io::{self, Error, ErrorKind};
use std::{
pin::Pin,
task::{Context, Poll}
};
use futures::future;
use rand::{
Rng,
rngs::StdRng,
SeedableRng
};
use tower::service_fn;
use http::Uri;
use hyper_util::rt::TokioIo;
use bytes::BytesMut;
use futures::future;
use http::Uri;
@@ -24,6 +47,14 @@ use tokio::net::TcpStream;
use tower::service_fn;
use uuid;
use metrics;
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use tracing::info;
use tokio_util::sync::CancellationToken;
/// A pooled gRPC client with capacity tracking and error handling.
pub struct ConnectionPool {
inner: Mutex<Inner>,
@@ -42,18 +73,10 @@ pub struct ConnectionPool {
// The maximum duration a connection can be idle before being removed
max_idle_duration: Duration,
channel_semaphore: Arc<Semaphore>,
// This notify is signaled when a connection is released or created.
notify: Notify,
// When it is time to create a new connection for the pool, we signal
// a watch and a connection creation async wakes up and does the work.
cc_watch_tx: watch::Sender<bool>,
cc_watch_rx: watch::Receiver<bool>,
// To acquire a connection from the pool, send a request
// to this mpsc, and wait for a response.
request_tx: mpsc::Sender<mpsc::Sender<PooledClient>>,
shutdown_token: CancellationToken,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
}
struct Inner {
@@ -62,6 +85,8 @@ struct Inner {
// This is updated when a connection is dropped, or we fail
// to create a new connection.
last_connect_failure: Option<Instant>,
waiters: usize,
in_progress: usize,
}
struct ConnectionEntry {
@@ -77,6 +102,7 @@ pub struct PooledClient {
pub channel: Channel,
pool: Arc<ConnectionPool>,
id: uuid::Uuid,
permit: OwnedSemaphorePermit,
}
/// Wraps a `TcpStream`, buffers incoming data, and injects a random delay per fresh read/write.
pub struct TokioTcp {
@@ -206,7 +232,6 @@ impl AsyncWrite for TokioTcp {
}
impl ConnectionPool {
/// Create a new pool and spawn the background task that handles requests.
pub fn new(
endpoint: &String,
max_consumers: usize,
@@ -217,107 +242,165 @@ impl ConnectionPool {
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
) -> Arc<Self> {
let (request_tx, mut request_rx) = mpsc::channel::<mpsc::Sender<PooledClient>>(100);
let (watch_tx, watch_rx) = watch::channel(false);
let shutdown_token = CancellationToken::new();
let pool = Arc::new(Self {
inner: Mutex::new(Inner {
entries: HashMap::new(),
last_connect_failure: None,
waiters: 0,
in_progress: 0,
}),
notify: Notify::new(),
cc_watch_tx: watch_tx,
cc_watch_rx: watch_rx,
channel_semaphore: Arc::new(Semaphore::new(0)),
endpoint: endpoint.clone(),
max_consumers: max_consumers,
error_threshold: error_threshold,
connect_timeout: connect_timeout,
connect_backoff: connect_backoff,
max_idle_duration: max_idle_duration,
request_tx: request_tx,
max_delay_ms: max_delay_ms,
drop_rate: drop_rate,
hang_rate: hang_rate,
shutdown_token: shutdown_token.clone(),
aggregate_metrics: aggregate_metrics.clone(),
});
//
// Background task to handle requests and create connections.
//
// TODO: These should be canceled when the ConnectionPool is dropped
//
let bg_cc_pool = Arc::clone(&pool);
tokio::spawn(async move {
loop {
bg_cc_pool.create_connection().await;
}
});
let bg_pool = Arc::clone(&pool);
tokio::spawn(async move {
while let Some(responder) = request_rx.recv().await {
// TODO: This call should time out and return an error
let (id, channel) = bg_pool.acquire_connection().await;
let client = PooledClient {
channel,
pool: Arc::clone(&bg_pool),
id,
};
let _ = responder.send(client).await;
}
});
// Background task to sweep idle connections
// Cancelable background task to sweep idle connections
let sweeper_token = shutdown_token.clone();
let sweeper_pool = Arc::clone(&pool);
tokio::spawn(async move {
loop {
sweeper_pool.sweep_idle_connections().await;
sleep(Duration::from_secs(5)).await; // Run every 5 seconds
tokio::select! {
_ = sweeper_token.cancelled() => break,
_ = async {
sweeper_pool.sweep_idle_connections().await;
sleep(Duration::from_secs(5)).await;
} => {}
}
}
});
pool
}
// Sweep and remove idle connections
async fn sweep_idle_connections(&self) {
let mut inner = self.inner.lock().await;
let now = Instant::now();
inner.entries.retain(|_id, entry| {
if entry.active_consumers == 0
&& now.duration_since(entry.last_used) > self.max_idle_duration
{
// Remove idle connection
return false;
pub async fn shutdown(self: Arc<Self>) {
self.shutdown_token.cancel();
loop {
let all_idle = {
let inner = self.inner.lock().await;
inner.entries.values().all(|e| e.active_consumers == 0)
};
if all_idle {
break;
}
true
});
sleep(Duration::from_millis(100)).await;
}
// 4. Remove all entries
let mut inner = self.inner.lock().await;
inner.entries.clear();
}
async fn acquire_connection(&self) -> (uuid::Uuid, Channel) {
loop {
// Reuse an existing healthy connection if available
{
let mut inner = self.inner.lock().await;
// TODO: Use a heap, although the number of connections is small
if let Some((&id, entry)) = inner
.entries
.iter_mut()
.filter(|(_, e)| e.active_consumers < self.max_consumers)
.filter(|(_, e)| e.consecutive_errors < self.error_threshold)
.max_by_key(|(_, e)| e.active_consumers)
/// Sweep and remove idle connections safely, burning their permits.
async fn sweep_idle_connections(self: &Arc<Self>) {
let mut to_forget = Vec::new();
let now = Instant::now();
// Remove idle entries and collect permits to forget
{
let mut inner = self.inner.lock().await;
inner.entries.retain(|_, entry| {
if entry.active_consumers == 0
&& now.duration_since(entry.last_used) > self.max_idle_duration
{
entry.active_consumers += 1;
return (id, entry.channel.clone());
let semaphore = Arc::clone(&self.channel_semaphore);
if let Ok(permits) = semaphore.try_acquire_many_owned(self.max_consumers as u32) {
to_forget.push(permits);
return false; // remove this entry
}
}
// There is no usable connection, so notify the connection creation async to make one. (It is
// possible that a consumer will release a connection while the new one is being created, in
// which case we will use it right away, but the new connection will be created anyway.)
let _ = self.cc_watch_tx.send(true);
true
});
}
// Permanently consume those permits
for permit in to_forget {
permit.forget();
}
}
// If we have a permit already, get a connection out of the hash table
async fn get_conn_with_permit(self: Arc<Self>, permit: OwnedSemaphorePermit) -> PooledClient {
let mut inner = self.inner.lock().await;
// TODO: Use a heap, although the number of connections is small
if let Some((&id, entry)) = inner
.entries
.iter_mut()
.filter(|(_, e)| e.active_consumers < self.max_consumers)
.filter(|(_, e)| e.consecutive_errors < self.error_threshold)
.max_by_key(|(_, e)| e.active_consumers)
{
entry.active_consumers += 1;
let client = PooledClient {
channel: entry.channel.clone(),
pool: Arc::clone(&self),
id,
permit,
};
return client;
} else {
panic!("Corrupt state: no available connections with permit acquired.");
}
}
pub async fn get_client(self: Arc<Self>) -> Result<PooledClient, tonic::Status> {
if self.shutdown_token.is_cancelled() {
return Err(tonic::Status::unavailable("Pool is shutting down"));
}
// Try to get the semaphore. If it fails, we are out of connections, so
// request that a new connection be created.
let mut semaphore = Arc::clone(&self.channel_semaphore);
match semaphore.try_acquire_owned() {
Ok(permit_) => {
let pool_conn = self.get_conn_with_permit(permit_).await;
return Ok(pool_conn);
}
Err(_) => {
match self.aggregate_metrics {
Some(ref metrics) => {
metrics.retry_counters.with_label_values(&["sema_acquire_failed"]).inc();
}
None => {}
}
{
let mut inner = self.inner.lock().await;
inner.waiters += 1;
if inner.waiters > (inner.in_progress * self.max_consumers) {
let self_clone = Arc::clone(&self);
tokio::task::spawn(async move {
self_clone.create_connection().await;
});
inner.in_progress += 1;
}
}
// Wait for a connection to become available, either because it
// was created or because a connection was returned to the pool.
semaphore = Arc::clone(&self.channel_semaphore);
let conn_permit = semaphore.acquire_owned().await.unwrap();
{
let mut inner = self.inner.lock().await;
inner.waiters -= 1;
}
let pool_conn = self.get_conn_with_permit(conn_permit).await;
return Ok(pool_conn);
}
// Wait for a new connection, or for one of the consumers to release a connection
// TODO: Put this notify in a timeout
self.notify.notified().await;
}
}
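A worked example of the spawn heuristic in get_client() above: waiters is compared against in_progress * max_consumers, the capacity already on the way. With max_consumers = 4 and one connect in flight (in_progress = 1), the first four waiters simply park on the semaphore; a fifth waiter makes the comparison 5 > 4 and spawns a second create_connection() task. When nothing is in flight (in_progress = 0), the first waiter always triggers a connect.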
@@ -329,7 +412,6 @@ impl ConnectionPool {
// This is a custom connector that inserts delays and errors, for
// testing purposes. It would normally be disabled by the config.
let connector = service_fn(move |uri: Uri| {
let max_delay = max_delay_ms;
let drop_rate = drop_rate;
let hang_rate = hang_rate;
async move {
@@ -340,11 +422,6 @@ impl ConnectionPool {
return future::pending::<Result<TokioIo<TokioTcp>, std::io::Error>>().await;
}
if max_delay > 0 {
// Random delay before connecting
let delay = rng.gen_range(0..max_delay);
tokio::time::sleep(Duration::from_millis(delay)).await;
}
// Random drop (connect error)
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
return Err(std::io::Error::new(
@@ -363,30 +440,23 @@ impl ConnectionPool {
_ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")),
};
//let addr = uri.authority().unwrap().as_str();
let tcp = TcpStream::connect(addr).await?;
let tcpwrapper = TokioTcp::new(tcp, max_delay_ms);
Ok(TokioIo::new(tcpwrapper))
}
});
// Wait to be signalled to create a connection.
let mut recv = self.cc_watch_tx.subscribe();
if !*self.cc_watch_rx.borrow() {
while recv.changed().await.is_ok() {
if *self.cc_watch_rx.borrow() {
break;
}
}
}
let mut backoff_delay = self.connect_backoff;
loop {
// Back off.
// Loop because failure can occur while we are sleeping, so wait
// until the failure stopped for at least one backoff period.
loop {
if let Some(delay) = {
let inner = self.inner.lock().await;
inner.last_connect_failure.and_then(|at| {
(at.elapsed() < self.connect_backoff)
.then(|| self.connect_backoff - at.elapsed())
(at.elapsed() < backoff_delay)
.then(|| backoff_delay - at.elapsed())
})
} {
sleep(delay).await;
@@ -402,6 +472,14 @@ impl ConnectionPool {
// on this connection. (Requests made later on this channel will time out
// with the same timeout.)
//
match self.aggregate_metrics {
Some(ref metrics) => {
metrics.retry_counters.with_label_values(&["connection_attempt"]).inc();
}
None => {}
}
let attempt = tokio::time::timeout(
self.connect_timeout,
Endpoint::from_shared(self.endpoint.clone())
@@ -414,6 +492,12 @@ impl ConnectionPool {
match attempt {
Ok(Ok(channel)) => {
{
match self.aggregate_metrics {
Some(ref metrics) => {
metrics.retry_counters.with_label_values(&["connection_success"]).inc();
}
None => {}
}
let mut inner = self.inner.lock().await;
let id = uuid::Uuid::new_v4();
inner.entries.insert(
@@ -426,31 +510,34 @@ impl ConnectionPool {
last_used: Instant::now(),
},
);
self.notify.notify_one();
let _ = self.cc_watch_tx.send(false);
self.channel_semaphore.add_permits(self.max_consumers);
// decrement in progress connections
inner.in_progress -= 1;
return;
};
}
Ok(Err(_)) | Err(_) => {
match self.aggregate_metrics {
Some(ref metrics) => {
metrics.retry_counters.with_label_values(&["connect_failed"]).inc();
}
None => {}
}
let mut inner = self.inner.lock().await;
inner.last_connect_failure = Some(Instant::now());
// Add some jitter so that every connection doesn't retry at once
let jitter = rand::thread_rng().gen_range(0..=backoff_delay.as_millis() as u64);
backoff_delay = Duration::from_millis(backoff_delay.as_millis() as u64 + jitter);
// Do not delay longer than one minute
if backoff_delay > Duration::from_secs(60) {
backoff_delay = Duration::from_secs(60);
}
}
}
}
}
/// Get a client we can use to send gRPC messages.
pub async fn get_client(&self) -> PooledClient {
let (resp_tx, mut resp_rx) = mpsc::channel(1);
self.request_tx
.send(resp_tx)
.await
.expect("ConnectionPool task has shut down");
resp_rx
.recv()
.await
.expect("ConnectionPool task has shut down")
}
/// Return client to the pool, indicating success or error.
pub async fn return_client(&self, id: uuid::Uuid, success: bool) {
@@ -458,7 +545,6 @@ impl ConnectionPool {
let mut new_failure = false;
if let Some(entry) = inner.entries.get_mut(&id) {
entry.last_used = Instant::now();
// TODO: This should be a debug_assert
if entry.active_consumers <= 0 {
panic!("A consumer completed when active_consumers was zero!")
}
@@ -488,10 +574,9 @@ impl ConnectionPool {
if remove == 0 {
inner.entries.remove(&id);
}
} else {
self.notify.notify_one();
}
}
// The semaphore permit is released when the pooled client is dropped.
}
}
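At the call sites in lib.rs (below), get_client() is now fallible and the PooledClient it returns carries the OwnedSemaphorePermit. A hedged sketch of the intended caller flow, assuming a pool: Arc<ConnectionPool> from this module and eliding the actual RPC:

async fn with_pooled_channel(pool: std::sync::Arc<ConnectionPool>) -> Result<(), tonic::Status> {
    // May wait on the semaphore, or kick off creation of a new connection.
    let client = pool.get_client().await?;
    let _channel = client.channel.clone();
    // ... build a tonic service client from `_channel` and issue the RPC here ...

    // lib.rs reports the outcome through PooledClient::finish (not shown in
    // this diff) so the pool can track consecutive errors; the permit held by
    // `client` goes back to the semaphore when `client` is dropped.
    drop(client);
    Ok(())
}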

@@ -22,6 +22,8 @@ use utils::shard::ShardIndex;
use std::fmt::Debug;
mod client_cache;
use metrics::{IntCounter, IntCounterVec, core::Collector};
#[derive(Error, Debug)]
pub enum PageserverClientError {
#[error("could not connect to service: {0}")]
@@ -38,6 +40,42 @@ pub enum PageserverClientError {
Other(String),
}
#[derive(Clone, Debug)]
pub struct PageserverClientAggregateMetrics {
pub request_counters: IntCounterVec,
pub retry_counters: IntCounterVec,
}
impl PageserverClientAggregateMetrics {
pub fn new() -> Self {
let request_counters = IntCounterVec::new(
metrics::core::Opts::new(
"backend_requests_total",
"Number of requests from backends.",
),
&["request_kind"],
).unwrap();
let retry_counters = IntCounterVec::new(
metrics::core::Opts::new(
"backend_requests_retries_total",
"Number of retried requests from backends.",
),
&["request_kind"],
).unwrap();
Self {
request_counters,
retry_counters,
}
}
pub fn collect (&self) -> Vec<metrics::proto::MetricFamily> {
let mut metrics = Vec::new();
metrics.append(&mut self.request_counters.collect());
metrics.append(&mut self.retry_counters.collect());
metrics
}
}
pub struct PageserverClient {
_tenant_id: String,
_timeline_id: String,
@@ -51,6 +89,8 @@ pub struct PageserverClient {
auth_interceptor: AuthInterceptor,
client_cache_options: ClientCacheOptions,
aggregate_metrics: Option<Arc<PageserverClientAggregateMetrics>>,
}
pub struct ClientCacheOptions {
@@ -82,7 +122,7 @@ impl PageserverClient {
drop_rate: 0.0,
hang_rate: 0.0,
};
Self::new_with_config(tenant_id, timeline_id, auth_token, shard_map, options)
Self::new_with_config(tenant_id, timeline_id, auth_token, shard_map, options, None)
}
pub fn new_with_config(
tenant_id: &str,
@@ -90,7 +130,9 @@ impl PageserverClient {
auth_token: &Option<String>,
shard_map: HashMap<ShardIndex, String>,
options: ClientCacheOptions,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
) -> Self {
Self {
_tenant_id: tenant_id.to_string(),
_timeline_id: timeline_id.to_string(),
@@ -99,6 +141,7 @@ impl PageserverClient {
channels: RwLock::new(HashMap::new()),
auth_interceptor: AuthInterceptor::new(tenant_id, timeline_id, auth_token.as_deref()),
client_cache_options: options,
aggregate_metrics: metrics,
}
}
pub async fn process_check_rel_exists_request(
@@ -185,6 +228,13 @@ impl PageserverClient {
));
};
match self.aggregate_metrics {
Some(ref metrics) => {
metrics.request_counters.with_label_values(&["get_page"]).inc();
}
None => {}
}
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
@@ -305,7 +355,7 @@ impl PageserverClient {
let usable_pool: Arc<client_cache::ConnectionPool>;
match reused_pool {
Some(pool) => {
let pooled_client = pool.get_client().await;
let pooled_client = pool.get_client().await.unwrap();
return pooled_client;
}
None => {
@@ -323,6 +373,7 @@ impl PageserverClient {
self.client_cache_options.max_delay_ms,
self.client_cache_options.drop_rate,
self.client_cache_options.hang_rate,
self.aggregate_metrics.clone(),
);
let mut write_pool = self.channels.write().unwrap();
write_pool.insert(shard, new_pool.clone());
@@ -330,7 +381,7 @@ impl PageserverClient {
}
}
let pooled_client = usable_pool.get_client().await;
let pooled_client = usable_pool.get_client().await.unwrap();
return pooled_client;
}
}
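For readers who have not used the workspace metrics crate: the imports and call sites above (metrics::core::Opts, IntCounterVec, Encoder/TextEncoder, proto::MetricFamily) mirror the prometheus crate's API, so the whole pipeline is label, increment, collect, text-encode. A standalone sketch written against upstream prometheus as a stand-in (an assumption; the workspace crate is treated here as re-exporting equivalent types):

use prometheus::{core::Collector, Encoder, IntCounterVec, Opts, TextEncoder};

fn main() {
    // Same shape as PageserverClientAggregateMetrics::new() above.
    let retries = IntCounterVec::new(
        Opts::new(
            "backend_requests_retries_total",
            "Number of retried requests from backends.",
        ),
        &["request_kind"],
    )
    .unwrap();

    // Mirrors the counter bumps added in client_cache.rs.
    retries.with_label_values(&["connection_attempt"]).inc();
    retries.with_label_values(&["sema_acquire_failed"]).inc();

    // collect() yields MetricFamily protos; empty families are filtered out,
    // matching what pagebench's get_metrics() handler does before encoding.
    let families: Vec<_> = retries
        .collect()
        .into_iter()
        .filter(|f| !f.get_metric().is_empty())
        .collect();

    let mut buf = Vec::new();
    TextEncoder::new().encode(&families, &mut buf).unwrap();
    print!("{}", String::from_utf8(buf).unwrap());
}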

@@ -22,6 +22,10 @@ tracing.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
axum.workspace = true
http.workspace = true
metrics.workspace = true
pageserver_client.workspace = true
pageserver_client_grpc.workspace = true

@@ -23,6 +23,19 @@ use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use metrics;
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
@@ -157,6 +170,36 @@ pub(crate) fn main(args: Args) -> anyhow::Result<()> {
main_impl(args, thread_local_stats)
})
}
async fn get_metrics(State(state): State<Arc<pageserver_client_grpc::PageserverClientAggregateMetrics>>) -> Response {
use metrics::core::Collector;
let metrics = state.collect();
info!("metrics: {metrics:?}");
// When we call TextEncoder::encode() below, it will immediately return an
// error if a metric family has no metrics, so we need to preemptively
// filter out metric families with no metrics.
let metrics = metrics
.into_iter()
.filter(|m| !m.get_metric().is_empty())
.collect::<Vec<MetricFamily>>();
let encoder = TextEncoder::new();
let mut buffer = vec![];
if let Err(e) = encoder.encode(&metrics, &mut buffer) {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, "application/text")
.body(Body::from(e.to_string()))
.unwrap()
} else {
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap()
}
}
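Once this handler is mounted at /metrics and the listener below is bound to 127.0.0.1:9090, the counters can be inspected mid-run with, for example, curl http://127.0.0.1:9090/metrics.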
async fn main_impl(
args: Args,
@@ -164,6 +207,24 @@ async fn main_impl(
) -> anyhow::Result<()> {
let args: &'static Args = Box::leak(Box::new(args));
// Vector of pageserver clients
let client_metrics = Arc::new(pageserver_client_grpc::PageserverClientAggregateMetrics::new());
use axum::routing::get;
let app = Router::new()
.route("/metrics", get(get_metrics))
.with_state(client_metrics.clone());
// TODO: make configurable. Or listen on unix domain socket?
let listener = tokio::net::TcpListener::bind("127.0.0.1:9090")
.await
.unwrap();
tokio::spawn(async {
tracing::info!("metrics listener spawned");
axum::serve(listener, app).await.unwrap()
});
let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new(
reqwest::Client::new(), // TODO: support ssl_ca_file for https APIs in pagebench.
args.mgmt_api_endpoint.clone(),
@@ -322,6 +383,8 @@ async fn main_impl(
let rps_period = args
.per_client_rate
.map(|rps_limit| Duration::from_secs_f64(1.0 / (rps_limit as f64)));
let new_metrics = client_metrics.clone();
let make_worker: &dyn Fn(WorkerId) -> Pin<Box<dyn Send + Future<Output = ()>>> = &|worker_id| {
let ss = shared_state.clone();
let cancel = cancel.clone();
@@ -334,11 +397,12 @@ async fn main_impl(
rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len()))
.unwrap();
let new_value = new_metrics.clone();
Box::pin(async move {
if args.grpc_stream {
client_grpc_stream(args, worker_id, ss, cancel, rps_period, ranges, weights).await
} else if args.grpc {
client_grpc(args, worker_id, ss, cancel, rps_period, ranges, weights).await
client_grpc(args, worker_id, new_value, ss, cancel, rps_period, ranges, weights).await
} else {
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
}
@@ -485,6 +549,7 @@ async fn client_libpq(
async fn client_grpc(
args: &Args,
worker_id: WorkerId,
client_metrics: Arc<pageserver_client_grpc::PageserverClientAggregateMetrics>,
shared_state: Arc<SharedState>,
cancel: CancellationToken,
rps_period: Option<Duration>,
@@ -511,9 +576,13 @@ async fn client_grpc(
&None,
shard_map,
options,
Some(client_metrics.clone()),
);
let client = Arc::new(client);
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
let mut ticks_processed = 0;