Request Tracker Prototype

Does not include splitting requests across shards.
This commit is contained in:
Elizabeth Murray
2025-06-05 13:32:18 -07:00
committed by GitHub
parent 786888d93f
commit 68f18ccacf
10 changed files with 1282 additions and 134 deletions
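At a glance, the prototype is driven through the new ShardedRequestTracker added in request_tracker.rs. Below is a minimal sketch of the intended call pattern, assuming a placeholder endpoint URL and caller-supplied tenant/timeline IDs; it is illustrative, not part of this commit:

use std::collections::HashMap;
use pageserver_client_grpc::request_tracker::ShardedRequestTracker;
use utils::shard::ShardIndex;

async fn example(tenant_id: String, timeline_id: String, req: pageserver_page_api::GetPageRequest) {
let tracker = ShardedRequestTracker::new();
// Point the unsharded index at a placeholder endpoint.
let mut shard_urls = HashMap::new();
shard_urls.insert(ShardIndex::unsharded(), "http://localhost:51051".to_string());
tracker.update_shard_map(shard_urls, None, tenant_id, timeline_id, None).await;
// All requests currently route to the unsharded tracker.
let _resp = tracker.get_page(req).await;
}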

Cargo.lock generated
View File

@@ -4592,17 +4592,22 @@ dependencies = [
name = "pageserver_client_grpc"
version = "0.1.0"
dependencies = [
"async-trait",
"bytes",
"chrono",
"dashmap 5.5.0",
"futures",
"http 1.1.0",
"hyper 1.6.0",
"hyper-util",
"metrics",
"pageserver_api",
"pageserver_page_api",
"priority-queue",
"rand 0.8.5",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tower 0.4.13",

View File

@@ -19,7 +19,12 @@ hyper-util = "0.1.9"
hyper = "1.6.0"
metrics.workspace = true
priority-queue = "2.3.1"
async-trait = { version = "0.1" }
tokio-stream = "0.1"
dashmap = "5"
chrono = { version = "0.4", features = ["serde"] }
pageserver_page_api.workspace = true
pageserver_api.workspace = true
utils.workspace = true

View File

@@ -0,0 +1,296 @@
// examples/load_test.rs, generated by AI
use std::collections::{HashMap, HashSet};
use std::sync::{
Arc,
Mutex,
atomic::{AtomicU64, AtomicUsize, Ordering},
};
use std::time::{Duration, Instant};
use tokio::task;
use tokio::time::sleep;
use rand::Rng;
use tonic::Status;
use uuid::Uuid;
// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate.
// Adjust these paths if necessary.
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
// --------------------------------------
// GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections
// --------------------------------------
static CREATED: AtomicU64 = AtomicU64::new(0);
static DROPPED: AtomicU64 = AtomicU64::new(0);
// --------------------------------------
// MockConnection + Factory
// --------------------------------------
#[derive(Debug)]
pub struct MockConnection {
pub id: u64,
}
impl Clone for MockConnection {
fn clone(&self) -> Self {
// Every clone will eventually be dropped and bump DROPPED, so a clone
// must bump CREATED as well; otherwise the CREATED == DROPPED leak check
// at the end of main() would fail.
CREATED.fetch_add(1, Ordering::Relaxed);
MockConnection { id: self.id }
}
}
impl Drop for MockConnection {
fn drop(&mut self) {
// When a MockConnection actually gets dropped, bump the counter.
DROPPED.fetch_add(1, Ordering::SeqCst);
}
}
pub struct MockConnectionFactory {
counter: AtomicU64,
}
impl MockConnectionFactory {
pub fn new() -> Self {
MockConnectionFactory {
counter: AtomicU64::new(1),
}
}
}
#[async_trait::async_trait]
impl PooledItemFactory<MockConnection> for MockConnectionFactory {
/// The trait on ConnectionPool expects:
/// async fn create(&self, timeout: Duration)
/// -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed>;
///
/// On success: Ok(Ok(MockConnection))
/// On a simulated “gRPC” failure: Ok(Err(Status::…))
/// On a connect timeout: Err(tokio::time::error::Elapsed)
async fn create(
&self,
_timeout: Duration,
) -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed> {
// Simulate connection creation immediately succeeding.
CREATED.fetch_add(1, Ordering::SeqCst);
let next_id = self.counter.fetch_add(1, Ordering::Relaxed);
Ok(Ok(MockConnection { id: next_id }))
}
}
// --------------------------------------
// CLIENT WORKER
// --------------------------------------
//
// Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we:
// 1. Lock the shared Mutex<HashMap<u64, Arc<AtomicUsize>>> to fetch/insert an Arc<AtomicUsize> for this conn_id.
// 2. Lock the shared Mutex<HashSet<u64>> to record this conn_id as “seen.”
// 3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers.
// 4. Sleep 10–100 ms to simulate “work.”
// 5. Atomically decrement the counter.
// 6. Call `pooled.finish(Ok(()))` to return to the pool.
async fn client_worker(
pool: Arc<ConnectionPool<MockConnection>>,
usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>>,
seen_set: Arc<Mutex<HashSet<u64>>>,
max_consumers: usize,
worker_id: usize,
) {
for iteration in 0..10 {
match pool.clone().get_client().await {
Ok(pooled) => {
let conn: MockConnection = pooled.channel();
let conn_id = conn.id;
// 1. Fetch or insert the Arc<AtomicUsize> for this conn_id:
let counter_arc: Arc<AtomicUsize> = {
let mut guard = usage_map.lock().unwrap();
guard
.entry(conn_id)
.or_insert_with(|| Arc::new(AtomicUsize::new(0)))
.clone()
// MutexGuard is dropped here
};
// 2. Record this conn_id in the shared HashSet of “seen” IDs:
{
let mut seen_guard = seen_set.lock().unwrap();
seen_guard.insert(conn_id);
// MutexGuard is dropped immediately
}
// 3. Atomically bump the count for this connection ID
let prev = counter_arc.fetch_add(1, Ordering::SeqCst);
let current = prev + 1;
assert!(
current <= max_consumers,
"Connection {} exceeded max_consumers (got {})",
conn_id,
current
);
println!(
"[worker {}][iter {}] got MockConnection id={} ({} concurrent)",
worker_id, iteration, conn_id, current
);
// 4. Simulate some work (10–100 ms)
let delay_ms = rand::thread_rng().gen_range(10..100);
sleep(Duration::from_millis(delay_ms)).await;
// 5. Decrement the usage counter
let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst);
let after = prev2 - 1;
println!(
"[worker {}][iter {}] returning MockConnection id={} (now {} remain)",
worker_id, iteration, conn_id, after
);
// 6. Return to the pool (mark success)
pooled.finish(Ok(())).await;
}
Err(status) => {
eprintln!(
"[worker {}][iter {}] failed to get client: {:?}",
worker_id, iteration, status
);
}
}
// Small random pause before next iteration to spread out load
let pause = rand::thread_rng().gen_range(0..20);
sleep(Duration::from_millis(pause)).await;
}
}
#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
async fn main() {
// --------------------------------------
// 1. Create factory and shared instrumentation
// --------------------------------------
let factory = Arc::new(MockConnectionFactory::new());
// Shared map: connection ID → Arc<AtomicUsize>
let usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>> =
Arc::new(Mutex::new(HashMap::new()));
// Shared set: record each unique connection ID we actually saw
let seen_set: Arc<Mutex<HashSet<u64>>> = Arc::new(Mutex::new(HashSet::new()));
// --------------------------------------
// 2. Pool parameters
// --------------------------------------
let connect_timeout = Duration::from_millis(500);
let connect_backoff = Duration::from_millis(100);
let max_consumers = 100; // test limit
let error_threshold = 2; // the mock factory never fails, so this is never hit
let max_idle_duration = Duration::from_secs(2);
let max_total_connections = 3;
let aggregate_metrics = None;
let pool: Arc<ConnectionPool<MockConnection>> = ConnectionPool::new(
factory,
connect_timeout,
connect_backoff,
max_consumers,
error_threshold,
max_idle_duration,
max_total_connections,
aggregate_metrics,
);
// --------------------------------------
// 3. Spawn worker tasks
// --------------------------------------
let num_workers = 10000;
let mut handles = Vec::with_capacity(num_workers);
let start_time = Instant::now();
for worker_id in 0..num_workers {
let pool_clone = Arc::clone(&pool);
let usage_clone = Arc::clone(&usage_map);
let seen_clone = Arc::clone(&seen_set);
let mc = max_consumers;
let handle = task::spawn(async move {
client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await;
});
handles.push(handle);
}
// --------------------------------------
// 4. Wait for workers to finish
// --------------------------------------
for handle in handles {
let _ = handle.await;
}
let elapsed = Instant::now().duration_since(start_time);
println!(
"All {} workers completed in {:?}",
num_workers, elapsed
);
// --------------------------------------
// 5. Print the total number of unique connections seen so far
// --------------------------------------
let unique_count = {
let seen_guard = seen_set.lock().unwrap();
seen_guard.len()
};
println!("Total unique connections used by workers: {}", unique_count);
// --------------------------------------
// 6. Sleep so the background sweeper can run (max_idle_duration = 2 s)
// --------------------------------------
sleep(Duration::from_secs(3)).await;
// --------------------------------------
// 7. Shutdown the pool
// --------------------------------------
let shutdown_pool = Arc::clone(&pool);
shutdown_pool.shutdown().await;
println!("Pool.shutdown() returned.");
// --------------------------------------
// 8. Verify that no background task still holds an Arc clone of `pool`.
// If any task is still alive (sweeper/create_connection), strong_count > 1.
// --------------------------------------
sleep(Duration::from_secs(1)).await; // give tasks time to exit
let sc = Arc::strong_count(&pool);
assert!(
sc == 1,
"Pool tasks did not all terminate: Arc::strong_count = {} (expected 1)",
sc
);
println!("Verified: all pool tasks have terminated (strong_count == 1).");
// --------------------------------------
// 9. Verify no MockConnection was leaked:
// CREATED must equal DROPPED.
// --------------------------------------
let created = CREATED.load(Ordering::SeqCst);
let dropped = DROPPED.load(Ordering::SeqCst);
assert!(
created == dropped,
"Leaked connections: created={} but dropped={}",
created,
dropped
);
println!(
"Verified: no connections leaked (created = {}, dropped = {}).",
created, dropped
);
// --------------------------------------
// 10. Because `client_worker` asserted inside that no connection
// ever exceeded `max_consumers`, reaching this point means that check passed.
// --------------------------------------
println!("All per-connection usage stayed within max_consumers = {}.", max_consumers);
println!("Load test complete; exiting cleanly.");
}

View File

@@ -0,0 +1,160 @@
// examples/request_tracker_load_test.rs
use std::{sync::Arc, time::Duration};
use tokio;
use pageserver_client_grpc::request_tracker::RequestTracker;
use pageserver_client_grpc::request_tracker::MockStreamFactory;
use pageserver_client_grpc::request_tracker::StreamReturner;
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
use pageserver_client_grpc::ClientCacheOptions;
use pageserver_client_grpc::PageserverClientAggregateMetrics;
use pageserver_client_grpc::AuthInterceptor;
use pageserver_client_grpc::client_cache::ChannelFactory;
use tonic::{transport::{Channel}, Request};
use rand::prelude::*;
use pageserver_api::key::Key;
use utils::lsn::Lsn;
use utils::id::TenantTimelineId;
use futures::stream::FuturesOrdered;
use futures::StreamExt;
use chrono::Utc;
use pageserver_page_api::{GetPageClass, GetPageResponse};
use pageserver_page_api::proto;
#[derive(Clone)]
struct KeyRange {
timeline: TenantTimelineId,
timeline_lsn: Lsn,
start: i128,
end: i128,
}
impl KeyRange {
fn len(&self) -> i128 {
self.end - self.start
}
}
#[tokio::main]
async fn main() {
// 1) configure the client-pool behavior
let client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(10),
connect_backoff: Duration::from_millis(200),
max_consumers: 64,
error_threshold: 10,
max_idle_duration: Duration::from_secs(60),
max_total_connections: 12,
};
// 2) aggregate metrics collector
let metrics = Arc::new(PageserverClientAggregateMetrics::new());
let pool = ConnectionPool::<StreamReturner>::new(
Arc::new(MockStreamFactory::new(
)),
client_cache_options.connect_timeout,
client_cache_options.connect_backoff,
client_cache_options.max_consumers,
client_cache_options.error_threshold,
client_cache_options.max_idle_duration,
client_cache_options.max_total_connections,
Some(Arc::clone(&metrics)),
);
// -----------
// There is no mock for the unary connection pool, so for now just
// don't use this pool
//
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
"".to_string(),
client_cache_options.max_delay_ms,
client_cache_options.drop_rate,
client_cache_options.hang_rate,
));
let unary_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
Arc::clone(&channel_fact),
client_cache_options.connect_timeout,
client_cache_options.connect_backoff,
client_cache_options.max_consumers,
client_cache_options.error_threshold,
client_cache_options.max_idle_duration,
client_cache_options.max_total_connections,
Some(Arc::clone(&metrics)),
);
// -----------
// Dummy auth interceptor. This is not used in this test.
let auth_interceptor = AuthInterceptor::new("dummy_tenant_id",
"dummy_timeline_id",
None);
let mut tracker = RequestTracker::new(
pool,
unary_pool,
auth_interceptor,
);
// 4) fire off 500,000 requests in parallel
let mut handles = FuturesOrdered::new();
for i in 0..500000u64 {
let mut rng = rand::thread_rng();
let r = 0..=1000000i128;
let key: i128 = rng.gen_range(r.clone());
let key = Key::from_i128(key);
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
let req2 = proto::GetPageRequest {
request_id: i + 1, // unique, nonzero ID used to route the response back
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: if rng.gen_bool(0.5) {
u64::from(Lsn::MAX)
} else {
10000
},
not_modified_since_lsn: 10000,
}),
rel: Some(rel_tag.into()),
block_number: vec![block_no],
};
let req_model = pageserver_page_api::GetPageRequest::try_from(req2.clone());
// RequestTracker is Clone, so we can share it
let mut tr = tracker.clone();
let fut = async move {
let resp = tr.send_getpage_request(req_model.unwrap()).await.unwrap();
// sanity check: the mock echoes the request's ID back in the response
assert!(resp.request_id > 0);
};
handles.push_back(fut);
}
// print timestamp
println!("Starting 5000000 requests at: {}", chrono::Utc::now());
// 5) wait for them all
for _ in 0..500000 {
handles.next().await.expect("Failed to get next handle");
}
// print timestamp
println!("Finished 5000000 requests at: {}", chrono::Utc::now());
println!("✅ All 100000 requests completed successfully");
}

View File

@@ -31,6 +31,7 @@ use hyper_util::rt::TokioIo;
use tower::service_fn;
use tokio_util::sync::CancellationToken;
use async_trait::async_trait;
//
// The "TokioTcp" is flakey TCP network for testing purposes, in order
@@ -164,32 +165,132 @@ impl AsyncWrite for TokioTcp {
}
}
/// A pooled gRPC client with capacity tracking and error handling.
pub struct ConnectionPool {
inner: Mutex<Inner>,
#[async_trait]
pub trait PooledItemFactory<T>: Send + Sync + 'static {
/// Create a new pooled item.
async fn create(&self, connect_timeout: Duration) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>;
}
// Config options that apply to each connection
pub struct ChannelFactory {
endpoint: String,
max_consumers: usize,
error_threshold: usize,
connect_timeout: Duration,
connect_backoff: Duration,
// Parameters for testing
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
}
// The maximum duration a connection can be idle before being removed
impl ChannelFactory {
pub fn new(
endpoint: String,
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
) -> Self {
ChannelFactory {
endpoint,
max_delay_ms,
drop_rate,
hang_rate,
}
}
}
#[async_trait]
impl PooledItemFactory<Channel> for ChannelFactory {
async fn create(&self, connect_timeout: Duration) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> {
let max_delay_ms = self.max_delay_ms;
let drop_rate = self.drop_rate;
let hang_rate = self.hang_rate;
// This is a custom connector that inserts delays and errors, for
// testing purposes. It would normally be disabled by the config.
let connector = service_fn(move |uri: Uri| {
let drop_rate = drop_rate;
let hang_rate = hang_rate;
async move {
let mut rng = StdRng::from_entropy();
// Simulate an indefinite hang
if hang_rate > 0.0 && rng.gen_bool(hang_rate) {
// never completes, to test timeout
return future::pending::<Result<TokioIo<TokioTcp>, std::io::Error>>().await;
}
// Random drop (connect error)
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"simulated connect drop",
));
}
// Otherwise perform real TCP connect
let addr = match (uri.host(), uri.port()) {
// host + explicit port
(Some(host), Some(port)) => format!("{}:{}", host, port.as_str()),
// host only (no port)
(Some(host), None) => host.to_string(),
// neither? error out
_ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")),
};
let tcp = TcpStream::connect(addr).await?;
let tcpwrapper = TokioTcp::new(tcp, max_delay_ms);
Ok(TokioIo::new(tcpwrapper))
}
});
let attempt = tokio::time::timeout(
connect_timeout,
Endpoint::from_shared(self.endpoint.clone())
.expect("invalid endpoint")
.timeout(connect_timeout)
.connect_with_connector(connector),
)
.await;
match attempt {
Ok(Ok(channel)) => {
// Connection succeeded
Ok(Ok(channel))
}
Ok(Err(e)) => {
Ok(Err(tonic::Status::new(
tonic::Code::Unavailable,
format!("Failed to connect: {}", e),
)))
}
Err(e) => {
Err(e)
}
}
}
}
/// A pooled gRPC client with capacity tracking and error handling.
pub struct ConnectionPool<T> {
inner: Mutex<Inner<T>>,
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
connect_timeout: Duration,
connect_backoff: Duration,
/// The maximum number of consumers that can use a single connection.
max_consumers: usize,
/// The number of consecutive errors before a connection is removed from the pool.
error_threshold: usize,
/// The maximum duration a connection can be idle before being removed.
max_idle_duration: Duration,
max_total_connections: usize,
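/// Counts free capacity across all pooled connections; get_client waits on
/// this, and each borrowed client holds one permit until it is returned.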
channel_semaphore: Arc<Semaphore>,
shutdown_token: CancellationToken,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
}
struct Inner {
entries: HashMap<uuid::Uuid, ConnectionEntry>,
struct Inner<T> {
entries: HashMap<uuid::Uuid, ConnectionEntry<T>>,
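// Connections prioritized by active consumer count, so get_client can pop
// the busiest connection that still has spare capacity.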
pq: PriorityQueue<uuid::Uuid, usize>,
// This is updated when a connection is dropped, or we fail
// to create a new connection.
@@ -197,54 +298,50 @@ struct Inner {
waiters: usize,
in_progress: usize,
}
struct ConnectionEntry {
channel: Channel,
struct ConnectionEntry<T> {
channel: T,
active_consumers: usize,
consecutive_errors: usize,
last_used: Instant,
}
/// A client borrowed from the pool.
pub struct PooledClient {
pub channel: Channel,
pool: Arc<ConnectionPool>,
pub struct PooledClient<T> {
pub channel: T,
pool: Arc<ConnectionPool<T>>,
is_ok: bool,
id: uuid::Uuid,
permit: OwnedSemaphorePermit,
}
impl ConnectionPool {
impl<T: Clone + Send + 'static> ConnectionPool<T> {
pub fn new(
endpoint: &String,
max_consumers: usize,
error_threshold: usize,
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
connect_timeout: Duration,
connect_backoff: Duration,
max_consumers: usize,
error_threshold: usize,
max_idle_duration: Duration,
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
max_total_connections: usize,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
) -> Arc<Self> {
let shutdown_token = CancellationToken::new();
let pool = Arc::new(Self {
inner: Mutex::new(Inner {
inner: Mutex::new(Inner::<T> {
entries: HashMap::new(),
pq: PriorityQueue::new(),
last_connect_failure: None,
waiters: 0,
in_progress: 0,
}),
fact: Arc::clone(&fact),
connect_timeout,
connect_backoff,
max_consumers,
error_threshold,
max_idle_duration,
max_total_connections,
channel_semaphore: Arc::new(Semaphore::new(0)),
endpoint: endpoint.clone(),
max_consumers: max_consumers,
error_threshold: error_threshold,
connect_timeout: connect_timeout,
connect_backoff: connect_backoff,
max_idle_duration: max_idle_duration,
max_delay_ms: max_delay_ms,
drop_rate: drop_rate,
hang_rate: hang_rate,
shutdown_token: shutdown_token.clone(),
aggregate_metrics: aggregate_metrics.clone(),
});
@@ -325,7 +422,7 @@ impl ConnectionPool {
async fn get_conn_with_permit(
self: Arc<Self>,
permit: OwnedSemaphorePermit,
) -> Option<PooledClient> {
) -> Option<PooledClient<T>> {
let mut inner = self.inner.lock().await;
// Pop the highest-active-consumers connection. There are no connections
@@ -340,9 +437,10 @@ impl ConnectionPool {
entry.active_consumers += 1;
entry.last_used = Instant::now();
let client = PooledClient {
let client = PooledClient::<T> {
channel: entry.channel.clone(),
pool: Arc::clone(&self),
is_ok: true,
id,
permit: permit,
};
@@ -365,7 +463,7 @@ impl ConnectionPool {
}
}
pub async fn get_client(self: Arc<Self>) -> Result<PooledClient, tonic::Status> {
pub async fn get_client(self: Arc<Self>) -> Result<PooledClient<T>, tonic::Status> {
// The pool is shutting down. Don't accept new connections.
if self.shutdown_token.is_cancelled() {
return Err(tonic::Status::unavailable("Pool is shutting down"));
@@ -412,12 +510,16 @@ impl ConnectionPool {
//
let mut inner = self_clone.inner.lock().await;
inner.waiters += 1;
if inner.waiters >= (inner.in_progress * self_clone.max_consumers) {
let self_clone_spawn = Arc::clone(&self_clone);
tokio::task::spawn(async move {
self_clone_spawn.create_connection().await;
});
inner.in_progress += 1;
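// Spawn another connect only when waiters outnumber what the in-flight
// connects could serve (max_consumers each), and only while under the
// total-connection cap.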
if inner.waiters > (inner.in_progress * self_clone.max_consumers) {
if (inner.entries.len() + inner.in_progress) < self_clone.max_total_connections {
let self_clone_spawn = Arc::clone(&self_clone);
tokio::task::spawn(async move {
self_clone_spawn.create_connection().await;
});
inner.in_progress += 1;
}
}
}
// Wait for a connection to become available, either because it
@@ -446,46 +548,6 @@ impl ConnectionPool {
}
async fn create_connection(&self) -> () {
let max_delay_ms = self.max_delay_ms;
let drop_rate = self.drop_rate;
let hang_rate = self.hang_rate;
// This is a custom connector that inserts delays and errors, for
// testing purposes. It would normally be disabled by the config.
let connector = service_fn(move |uri: Uri| {
let drop_rate = drop_rate;
let hang_rate = hang_rate;
async move {
let mut rng = StdRng::from_entropy();
// Simulate an indefinite hang
if hang_rate > 0.0 && rng.gen_bool(hang_rate) {
// never completes, to test timeout
return future::pending::<Result<TokioIo<TokioTcp>, std::io::Error>>().await;
}
// Random drop (connect error)
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"simulated connect drop",
));
}
// Otherwise perform real TCP connect
let addr = match (uri.host(), uri.port()) {
// host + explicit port
(Some(host), Some(port)) => format!("{}:{}", host, port.as_str()),
// host only (no port)
(Some(host), None) => host.to_string(),
// neither? error out
_ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")),
};
let tcp = TcpStream::connect(addr).await?;
let tcpwrapper = TokioTcp::new(tcp, max_delay_ms);
Ok(TokioIo::new(tcpwrapper))
}
});
// Generate a random backoff to add some jitter so that connections
// don't all retry at the same time.
@@ -533,14 +595,9 @@ impl ConnectionPool {
None => {}
}
let attempt = tokio::time::timeout(
self.connect_timeout,
Endpoint::from_shared(self.endpoint.clone())
.expect("invalid endpoint")
.timeout(self.connect_timeout)
.connect_with_connector(connector),
)
.await;
let attempt = self.fact
.create(self.connect_timeout)
.await;
match attempt {
// Connection succeeded
@@ -559,7 +616,7 @@ impl ConnectionPool {
let id = uuid::Uuid::new_v4();
inner.entries.insert(
id,
ConnectionEntry {
ConnectionEntry::<T> {
channel: channel.clone(),
active_consumers: 0,
consecutive_errors: 0,
@@ -641,6 +698,11 @@ impl ConnectionPool {
inner.pq.remove(&id);
}
// remove from entries (a no-op if the id is already gone)
inner.entries.remove(&id);
inner.last_connect_failure = Some(Instant::now());
// The connection has been removed, its permits will be
@@ -661,18 +723,19 @@ impl ConnectionPool {
}
}
}
// The semaphore permit is released when the pooled client is dropped.
}
}
impl PooledClient {
pub fn channel(&self) -> Channel {
impl<T: Clone + Send + 'static> PooledClient<T> {
pub fn channel(&self) -> T {
return self.channel.clone();
}
pub async fn finish(self, result: Result<(), tonic::Status>) {
self.pool
.return_client(self.id, result.is_ok(), self.permit)
.await;
pub async fn finish(mut self, result: Result<(), tonic::Status>) {
self.is_ok = result.is_ok();
self.pool.return_client(
self.id,
self.is_ok,
self.permit,
).await;
}
}

View File

@@ -20,9 +20,16 @@ use pageserver_page_api::proto::PageServiceClient;
use utils::shard::ShardIndex;
use std::fmt::Debug;
mod client_cache;
pub mod client_cache;
pub mod request_tracker;
use tonic::transport::Channel;
use metrics::{IntCounterVec, core::Collector};
use crate::client_cache::{PooledItemFactory};
use tokio::sync::mpsc;
use async_trait::async_trait;
#[derive(Error, Debug)]
pub enum PageserverClientError {
@@ -77,6 +84,7 @@ impl PageserverClientAggregateMetrics {
metrics
}
}
pub struct PageserverClient {
_tenant_id: String,
_timeline_id: String,
@@ -85,7 +93,7 @@ pub struct PageserverClient {
shard_map: HashMap<ShardIndex, String>,
channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool>>>,
channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool<Channel>>>>,
auth_interceptor: AuthInterceptor,
@@ -93,13 +101,14 @@ pub struct PageserverClient {
aggregate_metrics: Option<Arc<PageserverClientAggregateMetrics>>,
}
#[derive(Clone)]
pub struct ClientCacheOptions {
pub max_consumers: usize,
pub error_threshold: usize,
pub connect_timeout: Duration,
pub connect_backoff: Duration,
pub max_idle_duration: Duration,
pub max_total_connections: usize,
pub max_delay_ms: u64,
pub drop_rate: f64,
pub hang_rate: f64,
@@ -119,6 +128,7 @@ impl PageserverClient {
connect_timeout: Duration::from_secs(5),
connect_backoff: Duration::from_secs(1),
max_idle_duration: Duration::from_secs(60),
max_total_connections: 100000,
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
@@ -349,13 +359,13 @@ impl PageserverClient {
///
/// Get a client from the pool for this shard, also creating the pool if it doesn't exist.
///
async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient {
let reused_pool: Option<Arc<client_cache::ConnectionPool>> = {
async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient<Channel> {
let reused_pool: Option<Arc<client_cache::ConnectionPool<Channel>>> = {
let channels = self.channels.read().unwrap();
channels.get(&shard).cloned()
};
let usable_pool: Arc<client_cache::ConnectionPool>;
let usable_pool: Arc<client_cache::ConnectionPool<Channel>>;
match reused_pool {
Some(pool) => {
let pooled_client = pool.get_client().await.unwrap();
@@ -365,17 +375,21 @@ impl PageserverClient {
// Create a new pool using client_cache_options
let new_pool: Arc<client_cache::ConnectionPool>;
new_pool = client_cache::ConnectionPool::new(
self.shard_map.get(&shard).unwrap(),
self.client_cache_options.max_consumers,
self.client_cache_options.error_threshold,
self.client_cache_options.connect_timeout,
self.client_cache_options.connect_backoff,
self.client_cache_options.max_idle_duration,
let channel_fact = Arc::new(client_cache::ChannelFactory::new(
self.shard_map.get(&shard).unwrap().clone(),
self.client_cache_options.max_delay_ms,
self.client_cache_options.drop_rate,
self.client_cache_options.hang_rate,
));
let new_pool: Arc<client_cache::ConnectionPool<Channel>> = client_cache::ConnectionPool::new(
channel_fact,
self.client_cache_options.connect_timeout,
self.client_cache_options.connect_backoff,
self.client_cache_options.max_consumers,
self.client_cache_options.error_threshold,
self.client_cache_options.max_idle_duration,
self.client_cache_options.max_total_connections,
self.aggregate_metrics.clone(),
);
let mut write_pool = self.channels.write().unwrap();
@@ -391,7 +405,7 @@ impl PageserverClient {
/// Inject tenant_id, timeline_id and authentication token to all pageserver requests.
#[derive(Clone)]
struct AuthInterceptor {
pub struct AuthInterceptor {
tenant_id: AsciiMetadataValue,
shard_id: Option<AsciiMetadataValue>,
timeline_id: AsciiMetadataValue,
@@ -400,7 +414,7 @@ struct AuthInterceptor {
}
impl AuthInterceptor {
fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self {
pub fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self {
Self {
tenant_id: tenant_id.parse().expect("could not parse tenant id"),
shard_id: None,

View File

@@ -0,0 +1,590 @@
//
// API visible to the spawner: just an async function call.
//
use std::sync::Arc;
use crate::client_cache;
use pageserver_page_api::GetPageRequest;
use pageserver_page_api::GetPageResponse;
use pageserver_page_api::*;
use pageserver_page_api::proto;
use crate::client_cache::ConnectionPool;
use crate::client_cache::ChannelFactory;
use crate::AuthInterceptor;
use tonic::{transport::{Channel}, Request};
use crate::ClientCacheOptions;
use crate::PageserverClientAggregateMetrics;
use tokio::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use utils::shard::ShardIndex;
use tokio_stream::wrappers::ReceiverStream;
use pageserver_page_api::proto::PageServiceClient;
use tonic::{
Status,
Code,
};
use async_trait::async_trait;
use std::time::Duration;
use client_cache::PooledItemFactory;
//use tracing::info;
//
// A mock stream pool that just returns a sending channel, and whenever a GetPageRequest
// comes in on that channel, it randomly sleeps before sending a GetPageResponse
//
#[derive(Clone)]
pub struct StreamReturner {
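// Requests go out on `sender`; responses are routed back through
// `sender_hashmap`, keyed by request_id, to the per-request channel.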
sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
sender_hashmap: Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>>>>,
}
pub struct MockStreamFactory {
}
impl MockStreamFactory {
pub fn new() -> Self {
MockStreamFactory {
}
}
}
#[async_trait]
impl PooledItemFactory<StreamReturner> for MockStreamFactory {
async fn create(&self, _connect_timeout: Duration) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> {
let (sender, mut receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
// Create a StreamReturner that will send requests to the receiver channel
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())),
};
let map : Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>>>>
= Arc::clone(&stream_returner.sender_hashmap);
tokio::spawn(async move {
while let Some(request) = receiver.recv().await {
// Break out of the loop with 0.1% chance, to simulate the stream dying
if rand::random::<f32>() < 0.001 {
break;
}
// Simulate 0–100 ms of processing time
let mapclone = Arc::clone(&map);
tokio::spawn(async move {
let sleep_ms = rand::random::<u64>() % 100;
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
let response = proto::GetPageResponse {
request_id: request.request_id,
..Default::default()
};
// look up stream in hash map
let mut hashmap = mapclone.lock().await;
if let Some(sender) = hashmap.get(&request.request_id) {
// Send the response to the original request sender
if let Err(e) = sender.send(Ok(response.clone())).await {
eprintln!("Failed to send response: {}", e);
}
hashmap.remove(&request.request_id);
} else {
eprintln!("No sender found for request ID: {}", request.request_id);
}
});
}
// Close every sender stream in the hashmap
let hashmap = map.lock().await;
for sender in hashmap.values() {
let error = Status::new(Code::Unknown, "Stream closed");
if let Err(e) = sender.send(Err(error)).await {
eprintln!("Failed to send close response: {}", e);
}
}
});
Ok(Ok(stream_returner))
}
}
pub struct StreamFactory {
connection_pool: Arc<client_cache::ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
}
impl StreamFactory {
pub fn new(
connection_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
) -> Self {
StreamFactory {
connection_pool,
auth_interceptor,
shard,
}
}
}
#[async_trait]
impl PooledItemFactory<StreamReturner> for StreamFactory {
async fn create(&self, _connect_timeout: Duration) ->
Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed>
{
let pool_clone : Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool);
let pooled_client = match pool_clone.get_client().await {
Ok(client) => client,
Err(status) => return Ok(Err(status)),
};
let channel = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let (sender, receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
let outbound = ReceiverStream::new(receiver);
let client_resp = client
.get_pages(Request::new(outbound))
.await;
match client_resp {
Err(status) => {
// TODO: Convert this error correctly
Ok(Err(tonic::Status::new(
status.code(),
format!("Failed to connect to pageserver: {}", status.message()),
)))
}
Ok(resp) => {
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())),
};
let map : Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>>>>
= Arc::clone(&stream_returner.sender_hashmap);
tokio::spawn(async move {
let map_clone = Arc::clone(&map);
let mut inner = resp.into_inner();
loop {
let resp = inner.message().await;
// Stop on a transport error or a clean end-of-stream (Ok(None)).
let response = match resp {
Ok(Some(response)) => response,
Ok(None) | Err(_) => break,
};
// look up stream in hash map
let mut hashmap = map_clone.lock().await;
if let Some(sender) = hashmap.get(&response.request_id) {
// Send the response to the original request sender
if let Err(e) = sender.send(Ok(response.clone())).await {
eprintln!("Failed to send response: {}", e);
}
hashmap.remove(&response.request_id);
} else {
eprintln!("No sender found for request ID: {}", response.request_id);
}
}
// Close every sender stream in the hashmap
let hashmap = map_clone.lock().await;
for sender in hashmap.values() {
let error = Status::new(Code::Unknown, "Stream closed");
if let Err(e) = sender.send(Err(error)).await {
eprintln!("Failed to send close response: {}", e);
}
}
});
Ok(Ok(stream_returner))
}
}
}
}
#[derive(Clone)]
pub struct RequestTracker {
cur_id: Arc<AtomicU64>,
stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
}
impl RequestTracker {
pub fn new(stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
) -> Self {
RequestTracker {
cur_id: Arc::new(AtomicU64::new(0)),
stream_pool,
unary_pool,
auth_interceptor,
shard,
}
}
pub async fn send_process_check_rel_exists_request(
&self,
req: CheckRelExistsRequest,
) -> Result<bool, tonic::Status> {
loop {
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let request = proto::CheckRelExistsRequest::from(req.clone());
let response = ps_client.check_rel_exists(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().exists);
}
}
}
}
pub async fn send_process_get_rel_size_request(
&self,
req: GetRelSizeRequest,
) -> Result<u32, tonic::Status> {
loop {
// Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let request = proto::GetRelSizeRequest::from(req.clone());
let response = ps_client.get_rel_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_blocks);
}
}
}
}
pub async fn send_process_get_dbsize_request(
&self,
req: GetDbSizeRequest,
) -> Result<u64, tonic::Status> {
loop {
// Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let request = proto::GetDbSizeRequest::from(req.clone());
let response = ps_client.get_db_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_bytes);
}
}
}
}
pub async fn send_getpage_request(
&mut self,
req: GetPageRequest,
) -> Result<GetPageResponse, tonic::Status> {
loop {
let request = req.clone();
// Use the caller-provided request ID. Assigning IDs from cur_id is left
// commented out until the tracker owns ID allocation:
//let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1;
//request.request_id = request_id;
let request_id = request.request_id;
let (response_sender, mut response_receiver) =
tokio::sync::mpsc::channel::<Result<proto::GetPageResponse, Status>>(1);
// Get a stream from the stream pool
let pool_clone = Arc::clone(&self.stream_pool);
let sender_stream_pool = pool_clone.get_client().await;
let stream_returner = match sender_stream_pool {
Ok(stream_ret) => stream_ret,
Err(_e) => {
// retry
continue;
}
};
let returner = stream_returner.channel();
let map = returner.sender_hashmap.clone();
// Insert the response sender into the hashmap
{
let mut map_inner = map.lock().await;
map_inner.insert(request_id, response_sender);
}
let sent = returner.sender.send(proto::GetPageRequest::from(request))
.await;
if let Err(_e) = sent {
// Remove the request from the map if sending failed
{
let mut map_inner = map.lock().await;
// remove from hashmap
map_inner.remove(&request_id);
}
stream_returner.finish(Err(Status::new(Code::Unknown,
"Failed to send request"))).await;
continue;
}
let response: Option<Result<proto::GetPageResponse, Status>>;
response = response_receiver.recv().await;
match response {
Some(resp) => {
match resp {
Err(_status) => {
// Handle the case where the response was not received
stream_returner.finish(Err(Status::new(Code::Unknown,
"Failed to receive response"))).await;
continue;
},
Ok(resp) => {
stream_returner.finish(Result::Ok(())).await;
return Ok(resp.clone().into());
}
}
}
None => {
// Handle the case where the response channel was closed
stream_returner.finish(Err(Status::new(Code::Unknown,
"Response channel closed"))).await;
continue;
}
}
}
}
}
struct ShardedRequestTrackerInner {
// Hashmap of shard index to RequestTracker
trackers: std::collections::HashMap<ShardIndex, RequestTracker>,
}
pub struct ShardedRequestTracker {
inner: Arc<Mutex<ShardedRequestTrackerInner>>,
tcp_client_cache_options: ClientCacheOptions,
stream_client_cache_options: ClientCacheOptions,
}
//
// TODO: Functions in the ShardedRequestTracker should be able to timeout and
// cancel a request. The request should return an error if it is cancelled.
//
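// Sketch (illustrative only, not part of this commit): until cancellation
// is built into the tracker, a caller could bound get_page with a deadline.
// Dropping the timed-out future abandons the wait; cleaning up the
// tracker-side response-channel entry is exactly the TODO above.
async fn get_page_with_deadline(
tracker: &ShardedRequestTracker,
req: GetPageRequest,
deadline: Duration,
) -> Result<GetPageResponse, tonic::Status> {
match tokio::time::timeout(deadline, tracker.get_page(req)).await {
Ok(res) => res,
Err(_elapsed) => Err(tonic::Status::deadline_exceeded("get_page timed out")),
}
}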
impl ShardedRequestTracker {
pub fn new() -> Self {
//
// Default configuration for the client. These could be added to a config file
//
let tcp_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 8, // Streams per connection
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 8,
};
let stream_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 64, // Requests per stream
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 64, // Total allowable number of streams
};
ShardedRequestTracker {
inner: Arc::new(Mutex::new(ShardedRequestTrackerInner {
trackers: std::collections::HashMap::new(),
})),
tcp_client_cache_options,
stream_client_cache_options,
}
}
pub async fn update_shard_map(&self,
shard_urls: std::collections::HashMap<ShardIndex, String>,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
tenant_id: String, timeline_id: String, auth_str: Option<&str>) {
let mut trackers = std::collections::HashMap::new();
for (shard, endpoint_url) in shard_urls {
//
// Create a pool of streams for streaming get_page requests
//
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
endpoint_url.clone(),
self.tcp_client_cache_options.max_delay_ms,
self.tcp_client_cache_options.drop_rate,
self.tcp_client_cache_options.hang_rate,
));
let new_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
self.tcp_client_cache_options.max_consumers,
self.tcp_client_cache_options.error_threshold,
self.tcp_client_cache_options.max_idle_duration,
self.tcp_client_cache_options.max_total_connections,
metrics.clone(),
);
let auth_interceptor = AuthInterceptor::new(tenant_id.as_str(),
timeline_id.as_str(),
auth_str);
let stream_pool = ConnectionPool::<StreamReturner>::new(
Arc::new(StreamFactory::new(new_pool.clone(),
auth_interceptor.clone(), ShardIndex::unsharded())),
self.stream_client_cache_options.connect_timeout,
self.stream_client_cache_options.connect_backoff,
self.stream_client_cache_options.max_consumers,
self.stream_client_cache_options.error_threshold,
self.stream_client_cache_options.max_idle_duration,
self.stream_client_cache_options.max_total_connections,
metrics.clone(),
);
//
// Create a client pool for unary requests
//
let unary_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
self.tcp_client_cache_options.max_consumers,
self.tcp_client_cache_options.error_threshold,
self.tcp_client_cache_options.max_idle_duration,
self.tcp_client_cache_options.max_total_connections,
metrics.clone()
);
//
// Create a new RequestTracker for this shard
//
let new_tracker = RequestTracker::new(stream_pool, unary_pool, auth_interceptor, shard);
trackers.insert(shard, new_tracker);
}
let mut inner = self.inner.lock().await;
inner.trackers = trackers;
}
pub async fn get_page(
&self,
req: GetPageRequest,
) -> Result<GetPageResponse, tonic::Status> {
// TODO: derive the shard index from the request; for now all requests
// go to the unsharded index.
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let mut tracker = match inner.trackers.get(&shard_index) {
Some(t) => t.clone(),
None => return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))),
};
drop(inner);
// Call the send_getpage_request method on the tracker
tracker
.send_getpage_request(req)
.await
.map_err(|e| tonic::Status::unknown(format!("Failed to get page: {}", e)))
}
pub async fn process_get_dbsize_request(
&self,
request: GetDbSizeRequest,
) -> Result<u64, tonic::Status> {
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let tracker = match inner.trackers.get(&shard_index) {
Some(t) => t.clone(),
None => return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))),
};
drop(inner); // Release the lock before calling send_process_get_dbsize_request
// Call the send_process_get_dbsize_request method on the tracker
tracker.send_process_get_dbsize_request(request).await
}
pub async fn process_get_rel_size_request(
&self,
request: GetRelSizeRequest,
) -> Result<u32, tonic::Status> {
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let tracker = match inner.trackers.get(&shard_index) {
Some(t) => t.clone(),
None => return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))),
};
drop(inner); // Release the lock before calling send_process_get_rel_size_request
// Call the send_process_get_rel_size_request method on the tracker
tracker.send_process_get_rel_size_request(request).await
}
pub async fn process_check_rel_exists_request(
&self,
request: CheckRelExistsRequest,
) -> Result<bool, tonic::Status> {
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let tracker = match inner.trackers.get(&shard_index) {
Some(t) => t.clone(),
None => return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index))),
};
drop(inner); // Release the lock before calling send_process_check_rel_exists_request
// Call the send_process_check_rel_exists_request method on the tracker
tracker.send_process_check_rel_exists_request(request).await
}
}

View File

@@ -487,6 +487,7 @@ impl From<GetPageStatusCode> for i32 {
// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other
// shards will error.
#[derive(Clone)]
pub struct GetRelSizeRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,

View File

@@ -1,10 +1,11 @@
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::{HashSet, HashMap, VecDeque};
use std::future::Future;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::io::Error;
use anyhow::Context;
use async_trait::async_trait;
@@ -23,6 +24,8 @@ use tracing::info;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use tonic::transport::Channel;
use axum::Router;
use axum::body::Body;
use axum::extract::State;
@@ -427,6 +430,7 @@ async fn main_impl(
.await
.unwrap(),
),
};
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
})
@@ -694,6 +698,7 @@ impl Client for LibpqClient {
struct GrpcClient {
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
resp_rx: tonic::Streaming<proto::GetPageResponse>,
start_times: Vec<Instant>,
}
impl GrpcClient {
@@ -717,6 +722,7 @@ impl GrpcClient {
Ok(Self {
req_tx,
resp_rx: resp_stream,
start_times: Vec::new(),
})
}
}
@@ -741,6 +747,7 @@ impl Client for GrpcClient {
rel: Some(rel.into()),
block_number: blks,
};
self.start_times.push(Instant::now());
self.req_tx.send(req).await?;
Ok(())
}
@@ -755,3 +762,4 @@ impl Client for GrpcClient {
Ok((resp.request_id, resp.page_image))
}
}

View File

@@ -12,7 +12,7 @@ use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess};
use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest};
use crate::neon_request::{NeonIORequest, NeonIOResult};
use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable};
use pageserver_client_grpc::PageserverClient;
use pageserver_client_grpc::request_tracker::ShardedRequestTracker;
use pageserver_page_api as page_api;
use metrics::{IntCounter, IntCounterVec};
@@ -30,7 +30,7 @@ use utils::lsn::Lsn;
pub struct CommunicatorWorkerProcessStruct<'a> {
neon_request_slots: &'a [NeonIOHandle],
pageserver_client: PageserverClient,
request_tracker: ShardedRequestTracker,
pub(crate) cache: IntegratedCacheWriteAccess<'a>,
@@ -74,6 +74,7 @@ pub(super) async fn init(
initial_file_cache_size: u64,
file_cache_path: Option<PathBuf>,
) -> CommunicatorWorkerProcessStruct<'static> {
info!("Test log message");
let last_lsn = get_request_lsn();
let file_cache = if let Some(path) = file_cache_path {
@@ -97,7 +98,12 @@ pub(super) async fn init(
.integrated_cache_init_struct
.worker_process_init(last_lsn, file_cache);
let pageserver_client = PageserverClient::new(&tenant_id, &timeline_id, &auth_token, shard_map);
let mut request_tracker = ShardedRequestTracker::new();
request_tracker.update_shard_map(shard_map,
None,
tenant_id,
timeline_id,
auth_token.as_deref()).await;
let request_counters = IntCounterVec::new(
metrics::core::Opts::new(
@@ -148,7 +154,7 @@ pub(super) async fn init(
CommunicatorWorkerProcessStruct {
neon_request_slots: cis.neon_request_slots,
pageserver_client,
request_tracker,
cache,
submission_pipe_read_fd: cis.submission_pipe_read_fd,
next_request_id: AtomicU64::new(1),
@@ -257,7 +263,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
};
match self
.pageserver_client
.request_tracker
.process_check_rel_exists_request(page_api::CheckRelExistsRequest {
read_lsn: self.request_lsns(not_modified_since),
rel,
@@ -291,7 +297,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
let read_lsn = self.request_lsns(not_modified_since);
match self
.pageserver_client
.request_tracker
.process_get_rel_size_request(page_api::GetRelSizeRequest {
read_lsn,
rel: rel.clone(),
@@ -344,7 +350,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
};
match self
.pageserver_client
.request_tracker
.process_get_dbsize_request(page_api::GetDbSizeRequest {
read_lsn: self.request_lsns(not_modified_since),
db_oid: req.db_oid,
@@ -467,7 +473,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
// TODO: Use batched protocol
for (blkno, _lsn, dest, _guard) in cache_misses.iter() {
match self
.pageserver_client
.request_tracker
.get_page(page_api::GetPageRequest {
request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed),
request_class: page_api::GetPageClass::Normal,
@@ -477,11 +483,11 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
})
.await
{
Ok(page_images) => {
Ok(resp) => {
// Write the received page image directly to the shared memory location
// that the backend requested.
assert!(page_images.len() == 1);
let page_image = page_images[0].clone();
assert!(resp.page_images.len() == 1);
let page_image = resp.page_images[0].clone();
let src: &[u8] = page_image.as_ref();
let len = std::cmp::min(src.len(), dest.bytes_total() as usize);
unsafe {
@@ -545,7 +551,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
// TODO: Use batched protocol
for (blkno, _lsn, _guard) in cache_misses.iter() {
match self
.pageserver_client
.request_tracker
.get_page(page_api::GetPageRequest {
request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed),
request_class: page_api::GetPageClass::Prefetch,
@@ -555,13 +561,13 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
})
.await
{
Ok(page_images) => {
Ok(resp) => {
trace!(
"prefetch completed, remembering blk {} in rel {:?} in LFC",
*blkno, rel
);
assert!(page_images.len() == 1);
let page_image = page_images[0].clone();
assert!(resp.page_images.len() == 1);
let page_image = resp.page_images[0].clone();
self.cache
.remember_page(&rel, *blkno, page_image, not_modified_since, false)
.await;