Remove some old code

Erik Grinaker
2025-07-02 11:46:54 +02:00
parent 8ab8fc11a3
commit bf01145ae4
7 changed files with 12 additions and 2073 deletions

View File

@@ -1,273 +0,0 @@
// examples/load_test.rs, generated by AI
use std::collections::{HashMap, HashSet};
use std::sync::{
Arc, Mutex,
atomic::{AtomicU64, AtomicUsize, Ordering},
};
use std::time::{Duration, Instant};
use rand::Rng;
use tokio::task;
use tokio::time::sleep;
use tonic::Status;
// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate.
// Adjust these paths if necessary.
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
// --------------------------------------
// GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections
// --------------------------------------
static CREATED: AtomicU64 = AtomicU64::new(0);
static DROPPED: AtomicU64 = AtomicU64::new(0);
// --------------------------------------
// MockConnection + Factory
// --------------------------------------
#[derive(Debug)]
pub struct MockConnection {
pub id: u64,
}
impl Clone for MockConnection {
fn clone(&self) -> Self {
// Cloning produces another MockConnection instance that will eventually be dropped,
// so we bump CREATED here as well; otherwise CREATED could not match DROPPED.
CREATED.fetch_add(1, Ordering::Relaxed);
MockConnection { id: self.id }
}
}
impl Drop for MockConnection {
fn drop(&mut self) {
// When a MockConnection actually gets dropped, bump the counter.
DROPPED.fetch_add(1, Ordering::SeqCst);
}
}
#[derive(Default)]
pub struct MockConnectionFactory {
counter: AtomicU64,
}
#[async_trait::async_trait]
impl PooledItemFactory<MockConnection> for MockConnectionFactory {
/// The trait on ConnectionPool expects:
/// async fn create(&self, timeout: Duration)
/// -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed>;
///
/// On success: Ok(Ok(MockConnection))
/// On a simulated “gRPC” failure: Ok(Err(Status::…))
/// On a connect timeout: Err(tokio::time::error::Elapsed)
async fn create(
&self,
_timeout: Duration,
) -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed> {
// Simulate connection creation immediately succeeding.
CREATED.fetch_add(1, Ordering::SeqCst);
let next_id = self.counter.fetch_add(1, Ordering::Relaxed);
Ok(Ok(MockConnection { id: next_id }))
}
}
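// For contrast with the always-succeeding mock above, a hypothetical factory that
// exercises the Ok(Err(Status)) branch of the same contract might look roughly like
// this sketch (FlakyConnectionFactory and its 50% failure rate are illustrative and
// not part of the original example):
#[derive(Default)]
pub struct FlakyConnectionFactory {
counter: AtomicU64,
}
#[async_trait::async_trait]
impl PooledItemFactory<MockConnection> for FlakyConnectionFactory {
async fn create(
&self,
_timeout: Duration,
) -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed> {
// Fail roughly half of the attempts with a gRPC-style error so the pool's
// backoff and error-threshold paths get exercised.
if rand::thread_rng().gen_bool(0.5) {
return Ok(Err(Status::unavailable("simulated connect failure")));
}
CREATED.fetch_add(1, Ordering::SeqCst);
let next_id = self.counter.fetch_add(1, Ordering::Relaxed);
Ok(Ok(MockConnection { id: next_id }))
}
}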
// --------------------------------------
// CLIENT WORKER
// --------------------------------------
//
// Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we:
// 1. Lock the shared Mutex<HashMap<u64, Arc<AtomicUsize>>> to fetch/insert an Arc<AtomicUsize> for this conn_id.
// 2. Lock the shared Mutex<HashSet<u64>> to record this conn_id as “seen.”
// 3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers.
// 4. Sleep 10–100 ms to simulate “work.”
// 5. Atomically decrement the counter.
// 6. Call `pooled.finish(Ok(()))` to return to the pool.
async fn client_worker(
pool: Arc<ConnectionPool<MockConnection>>,
usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>>,
seen_set: Arc<Mutex<HashSet<u64>>>,
max_consumers: usize,
worker_id: usize,
) {
for iteration in 0..10 {
match pool.clone().get_client().await {
Ok(pooled) => {
let conn: MockConnection = pooled.channel();
let conn_id = conn.id;
// 1. Fetch or insert the Arc<AtomicUsize> for this conn_id:
let counter_arc: Arc<AtomicUsize> = {
let mut guard = usage_map.lock().unwrap();
guard
.entry(conn_id)
.or_insert_with(|| Arc::new(AtomicUsize::new(0)))
.clone()
// MutexGuard is dropped here
};
// 2. Record this conn_id in the shared HashSet of “seen” IDs:
{
let mut seen_guard = seen_set.lock().unwrap();
seen_guard.insert(conn_id);
// MutexGuard is dropped immediately
}
// 3. Atomically bump the count for this connection ID
let prev = counter_arc.fetch_add(1, Ordering::SeqCst);
let current = prev + 1;
assert!(
current <= max_consumers,
"Connection {conn_id} exceeded max_consumers (got {current})",
);
println!(
"[worker {worker_id}][iter {iteration}] got MockConnection id={conn_id} ({current} concurrent)",
);
// 4. Simulate some work (10–100 ms)
let delay_ms = rand::thread_rng().gen_range(10..100);
sleep(Duration::from_millis(delay_ms)).await;
// 5. Decrement the usage counter
let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst);
let after = prev2 - 1;
println!(
"[worker {worker_id}][iter {iteration}] returning MockConnection id={conn_id} (now {after} remain)",
);
// 6. Return to the pool (mark success)
pooled.finish(Ok(())).await;
}
Err(status) => {
eprintln!(
"[worker {worker_id}][iter {iteration}] failed to get client: {status:?}",
);
}
}
// Small random pause before next iteration to spread out load
let pause = rand::thread_rng().gen_range(0..20);
sleep(Duration::from_millis(pause)).await;
}
}
#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
async fn main() {
// --------------------------------------
// 1. Create factory and shared instrumentation
// --------------------------------------
let factory = Arc::new(MockConnectionFactory::default());
// Shared map: connection ID → Arc<AtomicUsize>
let usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>> =
Arc::new(Mutex::new(HashMap::new()));
// Shared set: record each unique connection ID we actually saw
let seen_set: Arc<Mutex<HashSet<u64>>> = Arc::new(Mutex::new(HashSet::new()));
// --------------------------------------
// 2. Pool parameters
// --------------------------------------
let connect_timeout = Duration::from_millis(500);
let connect_backoff = Duration::from_millis(100);
let max_consumers = 100; // test limit
let error_threshold = 2; // never reached: the mock factory and workers never report errors
let max_idle_duration = Duration::from_secs(2);
let max_total_connections = 3;
let aggregate_metrics = None;
let pool: Arc<ConnectionPool<MockConnection>> = ConnectionPool::new(
factory,
connect_timeout,
connect_backoff,
max_consumers,
error_threshold,
max_idle_duration,
max_total_connections,
aggregate_metrics,
);
// --------------------------------------
// 3. Spawn worker tasks
// --------------------------------------
let num_workers = 10000;
let mut handles = Vec::with_capacity(num_workers);
let start_time = Instant::now();
for worker_id in 0..num_workers {
let pool_clone = Arc::clone(&pool);
let usage_clone = Arc::clone(&usage_map);
let seen_clone = Arc::clone(&seen_set);
let mc = max_consumers;
let handle = task::spawn(async move {
client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await;
});
handles.push(handle);
}
// --------------------------------------
// 4. Wait for workers to finish
// --------------------------------------
for handle in handles {
let _ = handle.await;
}
let elapsed = Instant::now().duration_since(start_time);
println!("All {num_workers} workers completed in {elapsed:?}");
// --------------------------------------
// 5. Print the total number of unique connections seen so far
// --------------------------------------
let unique_count = {
let seen_guard = seen_set.lock().unwrap();
seen_guard.len()
};
println!("Total unique connections used by workers: {unique_count}");
// --------------------------------------
// 6. Sleep so the background sweeper can run (max_idle_duration = 2 s)
// --------------------------------------
sleep(Duration::from_secs(3)).await;
// --------------------------------------
// 7. Shutdown the pool
// --------------------------------------
let shutdown_pool = Arc::clone(&pool);
shutdown_pool.shutdown().await;
println!("Pool.shutdown() returned.");
// --------------------------------------
// 8. Verify that no background task still holds an Arc clone of `pool`.
// If any task is still alive (sweeper/create_connection), strong_count > 1.
// --------------------------------------
sleep(Duration::from_secs(1)).await; // give tasks time to exit
let sc = Arc::strong_count(&pool);
assert!(
sc == 1,
"Pool tasks did not all terminate: Arc::strong_count = {sc} (expected 1)",
);
println!("Verified: all pool tasks have terminated (strong_count == 1).");
// --------------------------------------
// 9. Verify no MockConnection was leaked:
// CREATED must equal DROPPED.
// --------------------------------------
let created = CREATED.load(Ordering::SeqCst);
let dropped = DROPPED.load(Ordering::SeqCst);
assert!(
created == dropped,
"Leaked connections: created={created} but dropped={dropped}",
);
println!("Verified: no connections leaked (created = {created}, dropped = {dropped}).");
// --------------------------------------
// 10. Because `client_worker` asserts internally that no connection ever
// exceeds `max_consumers`, reaching this point means that check passed.
// --------------------------------------
println!("All per-connection usage stayed within max_consumers = {max_consumers}.");
println!("Load test complete; exiting cleanly.");
}

View File

@@ -1,705 +0,0 @@
use std::{
collections::HashMap,
io::{self, Error, ErrorKind},
sync::Arc,
time::{Duration, Instant},
};
use priority_queue::PriorityQueue;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::sleep,
};
use tonic::transport::{Channel, Endpoint};
use uuid;
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::future;
use rand::{Rng, SeedableRng, rngs::StdRng};
use bytes::BytesMut;
use http::Uri;
use hyper_util::rt::TokioIo;
use tower::service_fn;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
//
// The "TokioTcp" is flakey TCP network for testing purposes, in order
// to simulate network errors and delays.
//
/// Wraps a `TcpStream`, buffers incoming data, and injects a random delay per fresh read/write.
pub struct TokioTcp {
tcp: TcpStream,
/// Maximum randomized delay in milliseconds
delay_ms: u64,
/// Next deadline instant for delay
deadline: Instant,
/// Internal buffer of previously-read data
buffer: BytesMut,
}
impl TokioTcp {
/// Create a new wrapper with given max delay (ms)
pub fn new(stream: TcpStream, delay_ms: u64) -> Self {
let initial = if delay_ms > 0 {
rand::thread_rng().gen_range(0..delay_ms)
} else {
0
};
let deadline = Instant::now() + Duration::from_millis(initial);
TokioTcp {
tcp: stream,
delay_ms,
deadline,
buffer: BytesMut::new(),
}
}
}
impl AsyncRead for TokioTcp {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
// Safe because TokioTcp is Unpin
let this = self.get_mut();
// 1) Drain any buffered data
if !this.buffer.is_empty() {
let to_copy = this.buffer.len().min(buf.remaining());
buf.put_slice(&this.buffer.split_to(to_copy));
return Poll::Ready(Ok(()));
}
// 2) If we're still before the deadline, schedule a wake and return Pending
let now = Instant::now();
if this.delay_ms > 0 && now < this.deadline {
let waker = cx.waker().clone();
let wait = this.deadline - now;
tokio::spawn(async move {
sleep(wait).await;
waker.wake_by_ref();
});
return Poll::Pending;
}
// 3) Past deadline: compute next random deadline
if this.delay_ms > 0 {
let next_ms = rand::thread_rng().gen_range(0..=this.delay_ms);
this.deadline = Instant::now() + Duration::from_millis(next_ms);
}
// 4) Perform actual read into a temporary buffer
let mut tmp = [0u8; 4096];
let mut rb = ReadBuf::new(&mut tmp);
match Pin::new(&mut this.tcp).poll_read(cx, &mut rb) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => {
let filled = rb.filled();
if filled.is_empty() {
// EOF or zero bytes
Poll::Ready(Ok(()))
} else {
this.buffer.extend_from_slice(filled);
let to_copy = this.buffer.len().min(buf.remaining());
buf.put_slice(&this.buffer.split_to(to_copy));
Poll::Ready(Ok(()))
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
}
}
}
impl AsyncWrite for TokioTcp {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
data: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
// 1) If before deadline, schedule wake and return Pending
let now = Instant::now();
if this.delay_ms > 0 && now < this.deadline {
let waker = cx.waker().clone();
let wait = this.deadline - now;
tokio::spawn(async move {
sleep(wait).await;
waker.wake_by_ref();
});
return Poll::Pending;
}
// 2) Past deadline: compute next random deadline
if this.delay_ms > 0 {
let next_ms = rand::thread_rng().gen_range(0..=this.delay_ms);
this.deadline = Instant::now() + Duration::from_millis(next_ms);
}
// 3) Actual write
Pin::new(&mut this.tcp).poll_write(cx, data)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut this.tcp).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut this.tcp).poll_shutdown(cx)
}
}
#[async_trait]
pub trait PooledItemFactory<T>: Send + Sync + 'static {
/// Create a new pooled item.
async fn create(
&self,
connect_timeout: Duration,
) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>;
}
pub struct ChannelFactory {
endpoint: String,
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
}
impl ChannelFactory {
pub fn new(endpoint: String, max_delay_ms: u64, drop_rate: f64, hang_rate: f64) -> Self {
ChannelFactory {
endpoint,
max_delay_ms,
drop_rate,
hang_rate,
}
}
}
#[async_trait]
impl PooledItemFactory<Channel> for ChannelFactory {
async fn create(
&self,
connect_timeout: Duration,
) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> {
let max_delay_ms = self.max_delay_ms;
let drop_rate = self.drop_rate;
let hang_rate = self.hang_rate;
// This is a custom connector that inserts delays and errors, for
// testing purposes. It would normally be disabled by the config.
let connector = service_fn(move |uri: Uri| {
let drop_rate = drop_rate;
let hang_rate = hang_rate;
async move {
let mut rng = StdRng::from_entropy();
// Simulate an indefinite hang
if hang_rate > 0.0 && rng.gen_bool(hang_rate) {
// never completes, to test timeout
return future::pending::<Result<TokioIo<TokioTcp>, std::io::Error>>().await;
}
// Random drop (connect error)
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
return Err(std::io::Error::other("simulated connect drop"));
}
// Otherwise perform real TCP connect
let addr = match (uri.host(), uri.port()) {
// host + explicit port
(Some(host), Some(port)) => format!("{}:{}", host, port.as_str()),
// host only (no port)
(Some(host), None) => host.to_string(),
// neither? error out
_ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")),
};
let tcp = TcpStream::connect(addr).await?;
let tcpwrapper = TokioTcp::new(tcp, max_delay_ms);
Ok(TokioIo::new(tcpwrapper))
}
});
let attempt = tokio::time::timeout(
connect_timeout,
Endpoint::from_shared(self.endpoint.clone())
.expect("invalid endpoint")
.timeout(connect_timeout)
.connect_with_connector(connector),
)
.await;
match attempt {
Ok(Ok(channel)) => {
// Connection succeeded
Ok(Ok(channel))
}
Ok(Err(e)) => Ok(Err(tonic::Status::new(
tonic::Code::Unavailable,
format!("Failed to connect: {e}"),
))),
Err(e) => Err(e),
}
}
}
/// A pooled gRPC client with capacity tracking and error handling.
pub struct ConnectionPool<T> {
inner: Mutex<Inner<T>>,
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
connect_timeout: Duration,
connect_backoff: Duration,
/// The maximum number of consumers that can use a single connection.
max_consumers: usize,
/// The number of consecutive errors before a connection is removed from the pool.
error_threshold: usize,
/// The maximum duration a connection can be idle before being removed.
max_idle_duration: Duration,
max_total_connections: usize,
channel_semaphore: Arc<Semaphore>,
shutdown_token: CancellationToken,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
}
struct Inner<T> {
entries: HashMap<uuid::Uuid, ConnectionEntry<T>>,
pq: PriorityQueue<uuid::Uuid, usize>,
// This is updated when a connection is dropped, or we fail
// to create a new connection.
last_connect_failure: Option<Instant>,
waiters: usize,
in_progress: usize,
}
struct ConnectionEntry<T> {
channel: T,
active_consumers: usize,
consecutive_errors: usize,
last_used: Instant,
}
/// A client borrowed from the pool.
pub struct PooledClient<T> {
pub channel: T,
pool: Arc<ConnectionPool<T>>,
is_ok: bool,
id: uuid::Uuid,
permit: OwnedSemaphorePermit,
}
impl<T: Clone + Send + 'static> ConnectionPool<T> {
#[allow(clippy::too_many_arguments)]
pub fn new(
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
connect_timeout: Duration,
connect_backoff: Duration,
max_consumers: usize,
error_threshold: usize,
max_idle_duration: Duration,
max_total_connections: usize,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
) -> Arc<Self> {
let shutdown_token = CancellationToken::new();
let pool = Arc::new(Self {
inner: Mutex::new(Inner::<T> {
entries: HashMap::new(),
pq: PriorityQueue::new(),
last_connect_failure: None,
waiters: 0,
in_progress: 0,
}),
fact: Arc::clone(&fact),
connect_timeout,
connect_backoff,
max_consumers,
error_threshold,
max_idle_duration,
max_total_connections,
channel_semaphore: Arc::new(Semaphore::new(0)),
shutdown_token: shutdown_token.clone(),
aggregate_metrics: aggregate_metrics.clone(),
});
// Cancelable background task to sweep idle connections
let sweeper_token = shutdown_token.clone();
let sweeper_pool = Arc::clone(&pool);
tokio::spawn(async move {
loop {
tokio::select! {
_ = sweeper_token.cancelled() => break,
_ = async {
sweeper_pool.sweep_idle_connections().await;
sleep(Duration::from_secs(5)).await;
} => {}
}
}
});
pool
}
pub async fn shutdown(self: Arc<Self>) {
self.shutdown_token.cancel();
loop {
let all_idle = {
let inner = self.inner.lock().await;
inner.entries.values().all(|e| e.active_consumers == 0)
};
if all_idle {
break;
}
sleep(Duration::from_millis(100)).await;
}
// 4. Remove all entries
let mut inner = self.inner.lock().await;
inner.entries.clear();
}
/// Sweep and remove idle connections. Their permits are not reclaimed here; they are
/// forgotten lazily when a consumer pops one and finds no matching connection.
async fn sweep_idle_connections(self: &Arc<Self>) {
let mut ids_to_remove = Vec::new();
let now = Instant::now();
// Remove idle entries from the entry map and the priority queue, so that no
// consumer will reserve them again.
{
let mut inner = self.inner.lock().await;
inner.entries.retain(|id, entry| {
if entry.active_consumers == 0
&& now.duration_since(entry.last_used) > self.max_idle_duration
{
// metric
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_swept"])
.inc();
}
ids_to_remove.push(*id);
return false; // remove this entry
}
true
});
// Remove the entries from the priority queue
for id in ids_to_remove {
inner.pq.remove(&id);
}
}
}
// If we have a permit already, get a connection out of the heap
async fn get_conn_with_permit(
self: Arc<Self>,
permit: OwnedSemaphorePermit,
) -> Option<PooledClient<T>> {
let mut inner = self.inner.lock().await;
// Pop the highest-active-consumers connection. There are no connections
// in the heap that have more than max_consumers active consumers.
if let Some((id, _cons)) = inner.pq.pop() {
let entry = inner
.entries
.get_mut(&id)
.expect("pq and entries got out of sync");
entry.active_consumers += 1;
entry.last_used = Instant::now();
let active_consumers = entry.active_consumers;
let client = PooledClient::<T> {
channel: entry.channel.clone(),
pool: Arc::clone(&self),
is_ok: true,
id,
permit,
};
// Reinsert with the updated priority, unless the connection is now at capacity.
if active_consumers < self.max_consumers {
inner.pq.push(id, active_consumers);
}
Some(client)
} else {
// If there is no connection to take, it is because permits for a connection
// need to drain. This can happen if a connection is removed because it has
// too many errors. It is taken out of the heap/hash table in this case, but
// we can't remove its permits until now.
//
// Just forget the permit and retry.
permit.forget();
None
}
}
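// Note on permit accounting, inferred from create_connection(), get_client() and
// return_client() below: every successfully created connection adds `max_consumers`
// permits, each borrowed PooledClient holds exactly one permit, and permits belonging
// to a connection that has since been removed are "forgotten" lazily, either above
// when the heap turns up empty or in return_client() when a failed connection is
// handed back.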
pub async fn get_client(self: Arc<Self>) -> Result<PooledClient<T>, tonic::Status> {
// The pool is shutting down. Don't accept new connections.
if self.shutdown_token.is_cancelled() {
return Err(tonic::Status::unavailable("Pool is shutting down"));
}
// A loop is necessary because when a connection is draining, we have to return
// a permit and retry.
loop {
let self_clone = Arc::clone(&self);
let mut semaphore = Arc::clone(&self_clone.channel_semaphore);
match semaphore.try_acquire_owned() {
Ok(permit_) => {
// We got a permit, so check the heap for a connection
// we can use.
let pool_conn = self_clone.get_conn_with_permit(permit_).await;
match pool_conn {
Some(pool_conn_) => {
return Ok(pool_conn_);
}
None => {
// No connection available. Forget the permit and retry.
continue;
}
}
}
Err(_) => {
// No permit was immediately available: fall through to the slow path,
// possibly spawning a new connection, then wait for a permit.
if let Some(ref metrics) = self_clone.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["sema_acquire_success"])
.inc();
}
{
//
// This is going to generate enough connections to handle a burst,
// but it may generate up to twice the number of connections needed
// in the worst case. Extra connections will go idle and be cleaned
// up.
//
let mut inner = self_clone.inner.lock().await;
inner.waiters += 1;
if inner.waiters > (inner.in_progress * self_clone.max_consumers)
&& (inner.entries.len() + inner.in_progress)
< self_clone.max_total_connections
{
let self_clone_spawn = Arc::clone(&self_clone);
tokio::task::spawn(async move {
self_clone_spawn.create_connection().await;
});
inner.in_progress += 1;
}
}
// Wait for a connection to become available, either because it
// was created or because a connection was returned to the pool
// by another consumer.
semaphore = Arc::clone(&self_clone.channel_semaphore);
let conn_permit = semaphore.acquire_owned().await.unwrap();
{
let mut inner = self_clone.inner.lock().await;
inner.waiters -= 1;
}
// We got a permit, check the heap for a connection.
let pool_conn = self_clone.get_conn_with_permit(conn_permit).await;
match pool_conn {
Some(pool_conn_) => {
return Ok(pool_conn_);
}
None => {
// No connection was found, forget the permit and retry.
continue;
}
}
}
}
}
}
async fn create_connection(&self) {
// Generate a random backoff to add some jitter so that connections
// don't all retry at the same time.
let mut backoff_delay = Duration::from_millis(
rand::thread_rng().gen_range(0..=self.connect_backoff.as_millis() as u64),
);
loop {
if self.shutdown_token.is_cancelled() {
return;
}
// Back off.
// Loop because another failure can occur while we are sleeping, so wait
// until failures have stopped for at least one backoff period. The backoff
// period includes some jitter, so that if multiple connections are
// failing, they don't all retry at the same time.
while let Some(delay) = {
let inner = self.inner.lock().await;
inner.last_connect_failure.and_then(|at| {
(at.elapsed() < backoff_delay).then(|| backoff_delay - at.elapsed())
})
} {
sleep(delay).await;
}
//
// Create a new connection.
//
// The connect timeout is also the timeout for an individual gRPC request
// on this connection. (Requests made later on this channel will time out
// with the same timeout.)
//
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_attempt"])
.inc();
}
let attempt = self.fact.create(self.connect_timeout).await;
match attempt {
// Connection succeeded
Ok(Ok(channel)) => {
{
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_success"])
.inc();
}
let mut inner = self.inner.lock().await;
let id = uuid::Uuid::new_v4();
inner.entries.insert(
id,
ConnectionEntry::<T> {
channel: channel.clone(),
active_consumers: 0,
consecutive_errors: 0,
last_used: Instant::now(),
},
);
inner.pq.push(id, 0);
inner.in_progress -= 1;
self.channel_semaphore.add_permits(self.max_consumers);
return;
};
}
// Connection failed, back off and retry
Ok(Err(_)) | Err(_) => {
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connect_failed"])
.inc();
}
let mut inner = self.inner.lock().await;
inner.last_connect_failure = Some(Instant::now());
// Add some jitter so that every connection doesn't retry at once
let jitter = rand::thread_rng().gen_range(0..=backoff_delay.as_millis() as u64);
backoff_delay =
Duration::from_millis(backoff_delay.as_millis() as u64 + jitter);
// Do not back off for longer than one minute
if backoff_delay > Duration::from_secs(60) {
backoff_delay = Duration::from_secs(60);
}
// continue the loop to retry
}
}
}
}
/// Return client to the pool, indicating success or error.
pub async fn return_client(&self, id: uuid::Uuid, success: bool, permit: OwnedSemaphorePermit) {
let mut inner = self.inner.lock().await;
if let Some(entry) = inner.entries.get_mut(&id) {
entry.last_used = Instant::now();
if entry.active_consumers == 0 {
panic!("A consumer completed when active_consumers was zero!")
}
entry.active_consumers -= 1;
if success {
// Reset the error counter on success, unless the connection has already
// crossed the threshold (in that case it stays marked for removal below).
if entry.consecutive_errors < self.error_threshold {
entry.consecutive_errors = 0;
}
} else {
entry.consecutive_errors += 1;
if entry.consecutive_errors == self.error_threshold {
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_dropped"])
.inc();
}
}
}
//
// Too many errors on this connection: remove it from the pool so it will not
// be selected for new consumers. Consumers that still hold it keep their own
// clone of the channel, and the underlying connection is dropped once they
// have all returned.
//
let active_consumers = entry.active_consumers;
if entry.consecutive_errors >= self.error_threshold {
// Too many errors: remove the connection permanently. Once it drains,
// it will be dropped. Both removals are no-ops if the id is already gone.
inner.pq.remove(&id);
inner.entries.remove(&id);
inner.last_connect_failure = Some(Instant::now());
// The connection has been removed; its permits will eventually be drained,
// because a consumer that pops a permit and finds no matching connection just
// forgets it. Forgetting the permit here, as each consumer returns, makes that
// process a little faster.
permit.forget();
} else {
// update its priority in the queue
if inner.pq.get_priority(&id).is_some() {
inner.pq.change_priority(&id, active_consumers);
} else {
// This connection is not in the heap, but it has space
// for more consumers. Put it back in the heap.
if active_consumers < self.max_consumers {
inner.pq.push(id, active_consumers);
}
}
}
}
}
}
impl<T: Clone + Send + 'static> PooledClient<T> {
pub fn channel(&self) -> T {
self.channel.clone()
}
pub async fn finish(mut self, result: Result<(), tonic::Status>) {
self.is_ok = result.is_ok();
self.pool
.return_client(self.id, self.is_ok, self.permit)
.await;
}
}
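// A minimal usage sketch of the pool above (not from the original file): the endpoint
// URL and all pool parameters are placeholders, and the async context is assumed to be
// provided by the caller.
#[allow(dead_code)]
async fn example_pool_usage() -> Result<(), tonic::Status> {
// A real endpoint, with the artificial delay/drop/hang knobs disabled.
let factory = Arc::new(ChannelFactory::new(
"http://localhost:51051".to_string(),
0,   // max_delay_ms
0.0, // drop_rate
0.0, // hang_rate
));
let pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
factory,
Duration::from_secs(5),     // connect_timeout
Duration::from_millis(100), // connect_backoff
100,                        // max_consumers per connection
5,                          // error_threshold
Duration::from_secs(60),    // max_idle_duration
16,                         // max_total_connections
None,                       // aggregate_metrics
);
// Borrow a channel, use it, and report the outcome back to the pool.
let client = Arc::clone(&pool).get_client().await?;
let _channel = client.channel();
// ... issue gRPC requests on `_channel` here ...
client.finish(Ok(())).await;
pool.shutdown().await;
Ok(())
}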

View File

@@ -1,451 +1,4 @@
//! Pageserver Data API client
//!
//! - Manage connections to pageserver
//! - Send requests to the correct shards
//!
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
mod client;
mod pool;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use thiserror::Error;
use tonic::metadata::AsciiMetadataValue;
use tonic::transport::Channel;
use pageserver_page_api::proto;
use pageserver_page_api::proto::PageServiceClient;
use pageserver_page_api::*;
use utils::shard::ShardIndex;
pub mod client;
pub mod client_cache;
pub mod pool;
pub mod request_tracker;
use metrics::{IntCounterVec, core::Collector};
#[derive(Error, Debug)]
pub enum PageserverClientError {
#[error("could not connect to service: {0}")]
ConnectError(#[from] tonic::transport::Error),
#[error("could not perform request: {0}`")]
RequestError(#[from] tonic::Status),
#[error("protocol error: {0}")]
ProtocolError(#[from] ProtocolError),
#[error("could not perform request: {0}`")]
InvalidUri(#[from] http::uri::InvalidUri),
#[error("could not perform request: {0}`")]
Other(String),
}
#[derive(Clone, Debug)]
pub struct PageserverClientAggregateMetrics {
pub request_counters: IntCounterVec,
pub retry_counters: IntCounterVec,
}
impl Default for PageserverClientAggregateMetrics {
fn default() -> Self {
Self::new()
}
}
impl PageserverClientAggregateMetrics {
pub fn new() -> Self {
let request_counters = IntCounterVec::new(
metrics::core::Opts::new(
"backend_requests_total",
"Number of requests from backends.",
),
&["request_kind"],
)
.unwrap();
let retry_counters = IntCounterVec::new(
metrics::core::Opts::new(
"backend_requests_retries_total",
"Number of retried requests from backends.",
),
&["request_kind"],
)
.unwrap();
Self {
request_counters,
retry_counters,
}
}
pub fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
let mut metrics = Vec::new();
metrics.append(&mut self.request_counters.collect());
metrics.append(&mut self.retry_counters.collect());
metrics
}
}
pub struct PageserverClient {
_tenant_id: String,
_timeline_id: String,
_auth_token: Option<String>,
shard_map: HashMap<ShardIndex, String>,
channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool<Channel>>>>,
auth_interceptor: AuthInterceptor,
client_cache_options: ClientCacheOptions,
aggregate_metrics: Option<Arc<PageserverClientAggregateMetrics>>,
}
#[derive(Clone)]
pub struct ClientCacheOptions {
pub max_consumers: usize,
pub error_threshold: usize,
pub connect_timeout: Duration,
pub connect_backoff: Duration,
pub max_idle_duration: Duration,
pub max_total_connections: usize,
pub max_delay_ms: u64,
pub drop_rate: f64,
pub hang_rate: f64,
}
impl PageserverClient {
/// TODO: this doesn't currently react to changes in the shard map.
pub fn new(
tenant_id: &str,
timeline_id: &str,
auth_token: &Option<String>,
shard_map: HashMap<ShardIndex, String>,
) -> Self {
let options = ClientCacheOptions {
max_consumers: 5000,
error_threshold: 5,
connect_timeout: Duration::from_secs(5),
connect_backoff: Duration::from_secs(1),
max_idle_duration: Duration::from_secs(60),
max_total_connections: 100000,
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
};
Self::new_with_config(tenant_id, timeline_id, auth_token, shard_map, options, None)
}
pub fn new_with_config(
tenant_id: &str,
timeline_id: &str,
auth_token: &Option<String>,
shard_map: HashMap<ShardIndex, String>,
options: ClientCacheOptions,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
) -> Self {
Self {
_tenant_id: tenant_id.to_string(),
_timeline_id: timeline_id.to_string(),
_auth_token: auth_token.clone(),
shard_map,
channels: RwLock::new(HashMap::new()),
auth_interceptor: AuthInterceptor::new(tenant_id, timeline_id, auth_token.as_deref()),
client_cache_options: options,
aggregate_metrics: metrics,
}
}
pub async fn process_check_rel_exists_request(
&self,
request: CheckRelExistsRequest,
) -> Result<bool, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::CheckRelExistsRequest::from(request);
let response = client.check_rel_exists(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
Ok(resp.get_ref().exists)
}
}
}
pub async fn process_get_rel_size_request(
&self,
request: GetRelSizeRequest,
) -> Result<u32, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::GetRelSizeRequest::from(request);
let response = client.get_rel_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
Ok(resp.get_ref().num_blocks)
}
}
}
// Request a single batch of pages
//
// TODO: This opens a new gRPC stream for every request, which is extremely inefficient
pub async fn get_page(
&self,
request: GetPageRequest,
) -> Result<Vec<Bytes>, PageserverClientError> {
// FIXME: calculate the shard number correctly
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::GetPageRequest::from(request);
let request_stream = futures::stream::once(std::future::ready(request));
let mut response_stream = client
.get_pages(tonic::Request::new(request_stream))
.await?
.into_inner();
let Some(response) = response_stream.next().await else {
return Err(PageserverClientError::Other(
"no response received for getpage request".to_string(),
));
};
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.request_counters
.with_label_values(&["get_page"])
.inc();
}
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
let response: GetPageResponse = resp.into();
Ok(response.page_images.to_vec())
}
}
}
// Open a stream for requesting pages
//
// TODO: This is a pretty low level interface, the caller should not need to be concerned
// with streams. But 'get_page' is currently very naive and inefficient.
pub async fn get_pages(
&self,
requests: impl Stream<Item = proto::GetPageRequest> + Send + 'static,
) -> std::result::Result<
tonic::Response<tonic::codec::Streaming<proto::GetPageResponse>>,
PageserverClientError,
> {
// FIXME: calculate the shard number correctly
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let response = client.get_pages(tonic::Request::new(requests)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => Ok(resp),
}
}
/// Process a request to get the size of a database.
pub async fn process_get_dbsize_request(
&self,
request: GetDbSizeRequest,
) -> Result<u64, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::GetDbSizeRequest::from(request);
let response = client.get_db_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
Ok(resp.get_ref().num_bytes)
}
}
}
/// Stream a base backup for the timeline from the pageserver.
pub async fn get_base_backup(
&self,
request: GetBaseBackupRequest,
gzip: bool,
) -> std::result::Result<
tonic::Response<tonic::codec::Streaming<proto::GetBaseBackupResponseChunk>>,
PageserverClientError,
> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
if gzip {
client = client.accept_compressed(tonic::codec::CompressionEncoding::Gzip);
}
let request = proto::GetBaseBackupRequest::from(request);
let response = client.get_base_backup(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
Ok(resp)
}
}
}
/// Get a client for the given shard
///
/// Get a client from the pool for this shard, also creating the pool if it doesn't exist.
///
async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient<Channel> {
let reused_pool: Option<Arc<client_cache::ConnectionPool<Channel>>> = {
let channels = self.channels.read().unwrap();
channels.get(&shard).cloned()
};
let usable_pool = match reused_pool {
Some(pool) => {
let pooled_client = pool.get_client().await.unwrap();
return pooled_client;
}
None => {
// No pool exists for this shard yet: create one using client_cache_options.
let channel_fact = Arc::new(client_cache::ChannelFactory::new(
self.shard_map.get(&shard).unwrap().clone(),
self.client_cache_options.max_delay_ms,
self.client_cache_options.drop_rate,
self.client_cache_options.hang_rate,
));
let new_pool = client_cache::ConnectionPool::new(
channel_fact,
self.client_cache_options.connect_timeout,
self.client_cache_options.connect_backoff,
self.client_cache_options.max_consumers,
self.client_cache_options.error_threshold,
self.client_cache_options.max_idle_duration,
self.client_cache_options.max_total_connections,
self.aggregate_metrics.clone(),
);
let mut write_pool = self.channels.write().unwrap();
write_pool.insert(shard, new_pool.clone());
new_pool.clone()
}
};
usable_pool.get_client().await.unwrap()
}
}
/// Inject tenant_id, timeline_id and authentication token into all pageserver requests.
#[derive(Clone)]
pub struct AuthInterceptor {
tenant_id: AsciiMetadataValue,
shard_id: Option<AsciiMetadataValue>,
timeline_id: AsciiMetadataValue,
auth_header: Option<AsciiMetadataValue>, // including "Bearer " prefix
}
impl AuthInterceptor {
pub fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self {
Self {
tenant_id: tenant_id.parse().expect("could not parse tenant id"),
shard_id: None,
timeline_id: timeline_id.parse().expect("could not parse timeline id"),
auth_header: auth_token
.map(|t| format!("Bearer {t}"))
.map(|t| t.parse().expect("could not parse auth token")),
}
}
fn for_shard(&self, shard_id: ShardIndex) -> Self {
let mut with_shard = self.clone();
with_shard.shard_id = Some(
shard_id
.to_string()
.parse()
.expect("could not parse shard id"),
);
with_shard
}
}
impl tonic::service::Interceptor for AuthInterceptor {
fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
req.metadata_mut()
.insert("neon-tenant-id", self.tenant_id.clone());
if let Some(shard_id) = &self.shard_id {
req.metadata_mut().insert("neon-shard-id", shard_id.clone());
}
req.metadata_mut()
.insert("neon-timeline-id", self.timeline_id.clone());
if let Some(auth_header) = &self.auth_header {
req.metadata_mut()
.insert("authorization", auth_header.clone());
}
Ok(req)
}
}
pub use client::PageserverClient;
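// A hypothetical end-to-end use of the PageserverClient defined above (not from the
// original file): the tenant/timeline IDs and the endpoint URL are placeholders, and
// the CheckRelExistsRequest is assumed to be built by the caller.
#[allow(dead_code)]
async fn example_client_usage(
req: CheckRelExistsRequest,
) -> Result<bool, PageserverClientError> {
// The client currently routes all metadata requests through the unsharded index.
let mut shard_map = HashMap::new();
shard_map.insert(
ShardIndex::unsharded(),
"http://pageserver.example:51051".to_string(),
);
let client = PageserverClient::new("tenant-id", "timeline-id", &None, shard_map);
client.process_check_rel_exists_request(req).await
}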

View File

@@ -39,7 +39,7 @@ use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
use tonic::transport::{Channel, Endpoint};
use tracing::warn;
use pageserver_page_api::{self as page_api, GetPageRequest, GetPageResponse};
use pageserver_page_api as page_api;
use utils::id::{TenantId, TimelineId};
use utils::shard::ShardIndex;
@@ -358,9 +358,9 @@ pub struct StreamPool {
}
type StreamID = usize;
type RequestSender = Sender<(GetPageRequest, ResponseSender)>;
type RequestReceiver = Receiver<(GetPageRequest, ResponseSender)>;
type ResponseSender = oneshot::Sender<tonic::Result<GetPageResponse>>;
type RequestSender = Sender<(page_api::GetPageRequest, ResponseSender)>;
type RequestReceiver = Receiver<(page_api::GetPageRequest, ResponseSender)>;
type ResponseSender = oneshot::Sender<tonic::Result<page_api::GetPageResponse>>;
struct StreamEntry {
/// Sends caller requests to the stream task. The stream task exits when this is dropped.
@@ -400,7 +400,10 @@ impl StreamPool {
/// * Allow spinning up multiple streams concurrently, but don't overshoot limits.
///
/// For now, we just do something simple and functional, but very inefficient (linear scan).
pub async fn send(&self, req: GetPageRequest) -> tonic::Result<GetPageResponse> {
pub async fn send(
&self,
req: page_api::GetPageRequest,
) -> tonic::Result<page_api::GetPageResponse> {
// Acquire a permit. For simplicity, we drop it when this method returns. This may exceed
// the queue depth if a caller goes away while a request is in flight, but that's okay. We
// do the same for queue depth tracking.

View File

@@ -1,578 +0,0 @@
//! The request tracker dispatches GetPage- and other requests to pageservers, managing a pool of
//! connections and gRPC streams.
//!
//! There is usually one global instance of ShardedRequestTracker in an application, in particular
//! in the neon extension's communicator process. The application calls the async functions in
//! ShardedRequestTracker, which routes them to the correct pageservers, taking sharding into
//! account. In the future, there can be multiple pageservers per shard, and RequestTracker manages
//! load balancing between them, but that's not implemented yet.
use crate::AuthInterceptor;
use crate::ClientCacheOptions;
use crate::PageserverClientAggregateMetrics;
use crate::client_cache;
use crate::client_cache::ChannelFactory;
use crate::client_cache::ConnectionPool;
use pageserver_page_api::GetPageRequest;
use pageserver_page_api::GetPageResponse;
use pageserver_page_api::proto;
use pageserver_page_api::*;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use tonic::{Request, transport::Channel};
use utils::shard::ShardIndex;
use pageserver_page_api::proto::PageServiceClient;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Code, Status};
use async_trait::async_trait;
use std::time::Duration;
use client_cache::PooledItemFactory;
/// StreamReturner represents a gRPC stream to a pageserver.
///
/// To send a request:
/// 1. insert the request's ID, along with a channel to receive the response
/// 2. send the request to 'sender'
#[derive(Clone)]
pub struct StreamReturner {
sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
#[allow(clippy::type_complexity)]
sender_hashmap: Arc<
tokio::sync::Mutex<
Option<
std::collections::HashMap<
u64,
tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>,
>,
>,
>,
>,
}
pub struct StreamFactory {
connection_pool: Arc<client_cache::ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
}
impl StreamFactory {
pub fn new(
connection_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
) -> Self {
StreamFactory {
connection_pool,
auth_interceptor,
shard,
}
}
}
#[async_trait]
impl PooledItemFactory<StreamReturner> for StreamFactory {
async fn create(
&self,
_connect_timeout: Duration,
) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> {
let pool_clone: Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool);
let pooled_client = pool_clone.get_client().await;
let channel = pooled_client.unwrap().channel();
let mut client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let (sender, receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
let outbound = ReceiverStream::new(receiver);
let client_resp = client.get_pages(Request::new(outbound)).await;
match client_resp {
Err(status) => {
// TODO: Convert this error correctly
Ok(Err(tonic::Status::new(
status.code(),
format!("Failed to connect to pageserver: {}", status.message()),
)))
}
Ok(resp) => {
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(tokio::sync::Mutex::new(Some(
std::collections::HashMap::new(),
))),
};
let map = Arc::clone(&stream_returner.sender_hashmap);
tokio::spawn(async move {
let map_clone = Arc::clone(&map);
let mut inner = resp.into_inner();
loop {
match inner.message().await {
Err(e) => {
tracing::info!("error received on getpage stream: {e}");
break; // Exit the loop if no more messages
}
Ok(None) => {
break; // Sender closed the stream
}
Ok(Some(response)) => {
// look up stream in hash map
let mut hashmap = map_clone.lock().await;
let hashmap =
hashmap.as_mut().expect("no other task clears the hashmap");
// Remove the waiting sender for this request ID and forward the response to it.
if let Some(sender) = hashmap.remove(&response.request_id) {
if let Err(e) = sender.send(Ok(response)).await {
eprintln!("Failed to send response: {e}");
}
} else {
eprintln!(
"No sender found for request ID: {}",
response.request_id
);
}
}
}
}
// Don't accept any more requests
// Close every sender stream in the hashmap
let mut hashmap_opt = map_clone.lock().await;
let hashmap = hashmap_opt
.as_mut()
.expect("no other task clears the hashmap");
for sender in hashmap.values() {
let error = Status::new(Code::Unknown, "Stream closed");
if let Err(e) = sender.send(Err(error)).await {
eprintln!("Failed to send close response: {e}");
}
}
*hashmap_opt = None;
});
Ok(Ok(stream_returner))
}
}
}
}
#[derive(Clone)]
pub struct RequestTracker {
_cur_id: Arc<AtomicU64>,
stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
}
impl RequestTracker {
pub fn new(
stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
) -> Self {
let cur_id = Arc::new(AtomicU64::new(0));
RequestTracker {
_cur_id: cur_id.clone(),
stream_pool,
unary_pool,
auth_interceptor,
shard,
}
}
pub async fn send_process_check_rel_exists_request(
&self,
req: CheckRelExistsRequest,
) -> Result<bool, tonic::Status> {
loop {
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::CheckRelExistsRequest::from(req);
let response = ps_client
.check_rel_exists(tonic::Request::new(request))
.await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().exists);
}
}
}
}
pub async fn send_process_get_rel_size_request(
&self,
req: GetRelSizeRequest,
) -> Result<u32, tonic::Status> {
loop {
// Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::GetRelSizeRequest::from(req);
let response = ps_client.get_rel_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
tracing::info!("send_process_get_rel_size_request: got error {status}, retrying");
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_blocks);
}
}
}
}
pub async fn send_process_get_dbsize_request(
&self,
req: GetDbSizeRequest,
) -> Result<u64, tonic::Status> {
loop {
// Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(
channel,
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::GetDbSizeRequest::from(req);
let response = ps_client.get_db_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_bytes);
}
}
}
}
pub async fn send_getpage_request(
&mut self,
req: GetPageRequest,
) -> Result<GetPageResponse, tonic::Status> {
loop {
let request = req.clone();
// The request ID comes from the request itself; _cur_id is currently unused.
//let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1;
let request_id = request.request_id;
let (response_sender, mut response_receiver) =
tokio::sync::mpsc::channel::<Result<proto::GetPageResponse, Status>>(1);
//request.request_id = request_id;
// Get a stream from the stream pool
let pool_clone = Arc::clone(&self.stream_pool);
let sender_stream_pool = pool_clone.get_client().await;
let stream_returner = match sender_stream_pool {
Ok(stream_ret) => stream_ret,
Err(_e) => {
// retry
continue;
}
};
let returner = stream_returner.channel();
let map = returner.sender_hashmap.clone();
// Insert the response sender into the hashmap
{
if let Some(map_inner) = map.lock().await.as_mut() {
let old = map_inner.insert(request_id, response_sender);
// request IDs must be unique
if old.is_some() {
panic!("request with ID {request_id} is already in-flight");
}
} else {
// The stream was closed. Try a different one.
tracing::info!("stream was concurrently closed");
continue;
}
}
let sent = returner
.sender
.send(proto::GetPageRequest::from(request))
.await;
if let Err(_e) = sent {
// Remove the request from the map if sending failed
{
if let Some(map_inner) = map.lock().await.as_mut() {
// remove from hashmap
map_inner.remove(&request_id);
}
}
stream_returner
.finish(Err(Status::new(Code::Unknown, "Failed to send request")))
.await;
continue;
}
let response = response_receiver.recv().await;
match response {
Some(resp) => {
match resp {
Err(_status) => {
// Handle the case where the response was not received
stream_returner
.finish(Err(Status::new(
Code::Unknown,
"Failed to receive response",
)))
.await;
continue;
}
Ok(resp) => {
stream_returner.finish(Result::Ok(())).await;
return Ok(resp.clone().into());
}
}
}
None => {
// Handle the case where the response channel was closed
stream_returner
.finish(Err(Status::new(Code::Unknown, "Response channel closed")))
.await;
continue;
}
}
}
}
}
struct ShardedRequestTrackerInner {
// Hashmap of shard index to RequestTracker
trackers: std::collections::HashMap<ShardIndex, RequestTracker>,
}
pub struct ShardedRequestTracker {
inner: Arc<std::sync::Mutex<ShardedRequestTrackerInner>>,
tcp_client_cache_options: ClientCacheOptions,
stream_client_cache_options: ClientCacheOptions,
}
//
// TODO: Functions in the ShardedRequestTracker should be able to time out and
// cancel a request. The request should return an error if it is cancelled.
//
impl Default for ShardedRequestTracker {
fn default() -> Self {
ShardedRequestTracker::new()
}
}
impl ShardedRequestTracker {
pub fn new() -> Self {
//
// Default configuration for the client. These could be added to a config file
//
let tcp_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 8, // Streams per connection
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 8,
};
let stream_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 64, // Requests per stream
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 64, // Total allowable number of streams
};
ShardedRequestTracker {
inner: Arc::new(std::sync::Mutex::new(ShardedRequestTrackerInner {
trackers: std::collections::HashMap::new(),
})),
tcp_client_cache_options,
stream_client_cache_options,
}
}
pub async fn update_shard_map(
&self,
shard_urls: std::collections::HashMap<ShardIndex, String>,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
tenant_id: String,
timeline_id: String,
auth_str: Option<&str>,
) {
let mut trackers = std::collections::HashMap::new();
for (shard, endpoint_url) in shard_urls {
//
// Create a pool of streams for streaming get_page requests
//
let channel_fact: Arc<dyn PooledItemFactory<Channel> + Send + Sync> =
Arc::new(ChannelFactory::new(
endpoint_url.clone(),
self.tcp_client_cache_options.max_delay_ms,
self.tcp_client_cache_options.drop_rate,
self.tcp_client_cache_options.hang_rate,
));
let new_pool = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
self.tcp_client_cache_options.max_consumers,
self.tcp_client_cache_options.error_threshold,
self.tcp_client_cache_options.max_idle_duration,
self.tcp_client_cache_options.max_total_connections,
metrics.clone(),
);
let auth_interceptor =
AuthInterceptor::new(tenant_id.as_str(), timeline_id.as_str(), auth_str);
let stream_pool = ConnectionPool::<StreamReturner>::new(
Arc::new(StreamFactory::new(
new_pool.clone(),
auth_interceptor.clone(),
ShardIndex::unsharded(),
)),
self.stream_client_cache_options.connect_timeout,
self.stream_client_cache_options.connect_backoff,
self.stream_client_cache_options.max_consumers,
self.stream_client_cache_options.error_threshold,
self.stream_client_cache_options.max_idle_duration,
self.stream_client_cache_options.max_total_connections,
metrics.clone(),
);
//
// Create a client pool for unary requests
//
let unary_pool = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
self.tcp_client_cache_options.max_consumers,
self.tcp_client_cache_options.error_threshold,
self.tcp_client_cache_options.max_idle_duration,
self.tcp_client_cache_options.max_total_connections,
metrics.clone(),
);
//
// Create a new RequestTracker for this shard
//
let new_tracker = RequestTracker::new(stream_pool, unary_pool, auth_interceptor, shard);
trackers.insert(shard, new_tracker);
}
let mut inner = self.inner.lock().unwrap();
inner.trackers = trackers;
}
pub async fn get_page(&self, req: GetPageRequest) -> Result<GetPageResponse, tonic::Status> {
// Get shard index from the request and look up the RequestTracker instance for that shard
let shard_index = ShardIndex::unsharded(); // TODO!
let mut tracker = self.lookup_tracker_for_shard(shard_index)?;
let response = tracker.send_getpage_request(req).await;
match response {
Ok(resp) => Ok(resp),
Err(e) => Err(tonic::Status::unknown(format!("Failed to get page: {e}"))),
}
}
pub async fn process_get_dbsize_request(
&self,
request: GetDbSizeRequest,
) -> Result<u64, tonic::Status> {
// Current sharding model assumes that all metadata is present only at shard 0.
let tracker = self.lookup_tracker_for_shard(ShardIndex::unsharded())?;
tracker.send_process_get_dbsize_request(request).await
}
pub async fn process_get_rel_size_request(
&self,
request: GetRelSizeRequest,
) -> Result<u32, tonic::Status> {
// Current sharding model assumes that all metadata is present only at shard 0.
let tracker = self.lookup_tracker_for_shard(ShardIndex::unsharded())?;
tracker.send_process_get_rel_size_request(request).await
}
pub async fn process_check_rel_exists_request(
&self,
request: CheckRelExistsRequest,
) -> Result<bool, tonic::Status> {
// Current sharding model assumes that all metadata is present only at shard 0.
let tracker = self.lookup_tracker_for_shard(ShardIndex::unsharded())?;
tracker.send_process_check_rel_exists_request(request).await
}
#[allow(clippy::result_large_err)]
fn lookup_tracker_for_shard(
&self,
shard_index: ShardIndex,
) -> Result<RequestTracker, tonic::Status> {
let inner = self.inner.lock().unwrap();
if let Some(t) = inner.trackers.get(&shard_index) {
Ok(t.clone())
} else {
Err(tonic::Status::not_found(format!(
"Shard {shard_index} not found",
)))
}
}
}
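// A hypothetical end-to-end use of the ShardedRequestTracker defined above (not from
// the original file): the shard URL, tenant and timeline IDs are placeholders, and the
// GetPageRequest is assumed to be built by the caller.
#[allow(dead_code)]
async fn example_tracker_usage(req: GetPageRequest) -> Result<GetPageResponse, tonic::Status> {
let tracker = ShardedRequestTracker::new();
let mut shard_urls = std::collections::HashMap::new();
shard_urls.insert(
ShardIndex::unsharded(),
"http://pageserver.example:51051".to_string(),
);
// Install the shard map; this builds the per-shard connection and stream pools.
tracker
.update_shard_map(
shard_urls,
None,                      // metrics
"tenant-id".to_string(),   // tenant_id
"timeline-id".to_string(), // timeline_id
None,                      // auth token
)
.await;
// get_page routes the request to the tracker for the matching shard
// (currently always the unsharded index).
tracker.get_page(req).await
}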

View File

@@ -26,17 +26,6 @@ use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
@@ -185,62 +174,12 @@ pub(crate) fn main(args: Args) -> anyhow::Result<()> {
main_impl(args, thread_local_stats)
})
}
async fn get_metrics(
State(state): State<Arc<pageserver_client_grpc::PageserverClientAggregateMetrics>>,
) -> Response {
let metrics = state.collect();
info!("metrics: {metrics:?}");
// When we call TextEncoder::encode() below, it will immediately return an
// error if a metric family has no metrics, so we need to preemptively
// filter out metric families with no metrics.
let metrics = metrics
.into_iter()
.filter(|m| !m.get_metric().is_empty())
.collect::<Vec<MetricFamily>>();
let encoder = TextEncoder::new();
let mut buffer = vec![];
if let Err(e) = encoder.encode(&metrics, &mut buffer) {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, "application/text")
.body(Body::from(e.to_string()))
.unwrap()
} else {
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap()
}
}
async fn main_impl(
args: Args,
all_thread_local_stats: AllThreadLocalStats<request_stats::Stats>,
) -> anyhow::Result<()> {
let args: &'static Args = Box::leak(Box::new(args));
// Aggregate metrics shared by all pageserver clients
let client_metrics = Arc::new(pageserver_client_grpc::PageserverClientAggregateMetrics::new());
use axum::routing::get;
let app = Router::new()
.route("/metrics", get(get_metrics))
.with_state(client_metrics.clone());
// TODO: make configurable. Or listen on unix domain socket?
let listener = tokio::net::TcpListener::bind("127.0.0.1:9090")
.await
.unwrap();
tokio::spawn(async {
tracing::info!("metrics listener spawned");
axum::serve(listener, app).await.unwrap()
});
let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new(
reqwest::Client::new(), // TODO: support ssl_ca_file for https APIs in pagebench.
args.mgmt_api_endpoint.clone(),

View File

@@ -13,7 +13,7 @@ use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess};
use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest};
use crate::neon_request::{NeonIORequest, NeonIOResult};
use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable};
use pageserver_client_grpc::client::PageserverClient;
use pageserver_client_grpc::PageserverClient;
use pageserver_page_api as page_api;
use metrics::{IntCounter, IntCounterVec};