mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-10 15:02:56 +00:00
Request Tracker Prototype
Does not include splitting requests across shards.
This commit is contained in:
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -4592,17 +4592,22 @@ dependencies = [
|
||||
name = "pageserver_client_grpc"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"dashmap 5.5.0",
|
||||
"futures",
|
||||
"http 1.1.0",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"metrics",
|
||||
"pageserver_api",
|
||||
"pageserver_page_api",
|
||||
"priority-queue",
|
||||
"rand 0.8.5",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tonic 0.13.1",
|
||||
"tower 0.4.13",
|
||||
|
||||
@@ -19,7 +19,12 @@ hyper-util = "0.1.9"
|
||||
hyper = "1.6.0"
|
||||
metrics.workspace = true
|
||||
priority-queue = "2.3.1"
|
||||
async-trait = { version = "0.1" }
|
||||
tokio-stream = "0.1"
|
||||
dashmap = "5"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
|
||||
pageserver_page_api.workspace = true
|
||||
pageserver_api.workspace = true
|
||||
utils.workspace = true
|
||||
|
||||
296
pageserver/client_grpc/examples/load_test.rs
Normal file
296
pageserver/client_grpc/examples/load_test.rs
Normal file
@@ -0,0 +1,296 @@
|
||||
// examples/load_test.rs, generated by AI
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::{
|
||||
Arc,
|
||||
Mutex,
|
||||
atomic::{AtomicU64, AtomicUsize, Ordering},
|
||||
};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::task;
|
||||
use tokio::time::sleep;
|
||||
use rand::Rng;
|
||||
use tonic::Status;
|
||||
use uuid::Uuid;
|
||||
|
||||
// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate.
|
||||
// Adjust these paths if necessary.
|
||||
use pageserver_client_grpc::client_cache::ConnectionPool;
|
||||
use pageserver_client_grpc::client_cache::PooledItemFactory;
|
||||
|
||||
// --------------------------------------
|
||||
// GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections
|
||||
// --------------------------------------
|
||||
static CREATED: AtomicU64 = AtomicU64::new(0);
|
||||
static DROPPED: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
// --------------------------------------
|
||||
// MockConnection + Factory
|
||||
// --------------------------------------
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MockConnection {
|
||||
pub id: u64,
|
||||
}
|
||||
|
||||
impl Clone for MockConnection {
|
||||
fn clone(&self) -> Self {
|
||||
// Cloning a MockConnection does NOT count as “creating” a brand‐new connection,
|
||||
// so we do NOT bump CREATED here. We only bump CREATED in the factory’s `create()`.
|
||||
CREATED.fetch_add(1, Ordering::Relaxed);
|
||||
MockConnection { id: self.id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MockConnection {
|
||||
fn drop(&mut self) {
|
||||
// When a MockConnection actually gets dropped, bump the counter.
|
||||
DROPPED.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockConnectionFactory {
|
||||
counter: AtomicU64,
|
||||
}
|
||||
|
||||
impl MockConnectionFactory {
|
||||
pub fn new() -> Self {
|
||||
MockConnectionFactory {
|
||||
counter: AtomicU64::new(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PooledItemFactory<MockConnection> for MockConnectionFactory {
|
||||
/// The trait on ConnectionPool expects:
|
||||
/// async fn create(&self, timeout: Duration)
|
||||
/// -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed>;
|
||||
///
|
||||
/// On success: Ok(Ok(MockConnection))
|
||||
/// On a simulated “gRPC” failure: Ok(Err(Status::…))
|
||||
/// On a transport/factory error: Err(Box<…>)
|
||||
async fn create(
|
||||
&self,
|
||||
_timeout: Duration,
|
||||
) -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed> {
|
||||
// Simulate connection creation immediately succeeding.
|
||||
CREATED.fetch_add(1, Ordering::SeqCst);
|
||||
let next_id = self.counter.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(Ok(MockConnection { id: next_id }))
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
// CLIENT WORKER
|
||||
// --------------------------------------
|
||||
//
|
||||
// Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we:
|
||||
// 1. Lock the shared Mutex<HashMap<u64, Arc<AtomicUsize>>> to fetch/insert an Arc<AtomicUsize> for this conn_id.
|
||||
// 2. Lock the shared Mutex<HashSet<u64>> to record this conn_id as “seen.”
|
||||
// 3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers.
|
||||
// 4. Sleep 10–100 ms to simulate “work.”
|
||||
// 5. Atomically decrement the counter.
|
||||
// 6. Call `pooled.finish(Ok(()))` to return to the pool.
|
||||
|
||||
async fn client_worker(
|
||||
pool: Arc<ConnectionPool<MockConnection>>,
|
||||
usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>>,
|
||||
seen_set: Arc<Mutex<HashSet<u64>>>,
|
||||
max_consumers: usize,
|
||||
worker_id: usize,
|
||||
) {
|
||||
for iteration in 0..10 {
|
||||
match pool.clone().get_client().await {
|
||||
Ok(pooled) => {
|
||||
let conn: MockConnection = pooled.channel();
|
||||
let conn_id = conn.id;
|
||||
|
||||
// 1. Fetch or insert the Arc<AtomicUsize> for this conn_id:
|
||||
let counter_arc: Arc<AtomicUsize> = {
|
||||
let mut guard = usage_map.lock().unwrap();
|
||||
guard
|
||||
.entry(conn_id)
|
||||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)))
|
||||
.clone()
|
||||
// MutexGuard is dropped here
|
||||
};
|
||||
|
||||
// 2. Record this conn_id in the shared HashSet of “seen” IDs:
|
||||
{
|
||||
let mut seen_guard = seen_set.lock().unwrap();
|
||||
seen_guard.insert(conn_id);
|
||||
// MutexGuard is dropped immediately
|
||||
}
|
||||
|
||||
// 3. Atomically bump the count for this connection ID
|
||||
let prev = counter_arc.fetch_add(1, Ordering::SeqCst);
|
||||
let current = prev + 1;
|
||||
assert!(
|
||||
current <= max_consumers,
|
||||
"Connection {} exceeded max_consumers (got {})",
|
||||
conn_id,
|
||||
current
|
||||
);
|
||||
|
||||
println!(
|
||||
"[worker {}][iter {}] got MockConnection id={} ({} concurrent)",
|
||||
worker_id, iteration, conn_id, current
|
||||
);
|
||||
|
||||
// 4. Simulate some work (10–100 ms)
|
||||
let delay_ms = rand::thread_rng().gen_range(10..100);
|
||||
sleep(Duration::from_millis(delay_ms)).await;
|
||||
|
||||
// 5. Decrement the usage counter
|
||||
let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst);
|
||||
let after = prev2 - 1;
|
||||
println!(
|
||||
"[worker {}][iter {}] returning MockConnection id={} (now {} remain)",
|
||||
worker_id, iteration, conn_id, after
|
||||
);
|
||||
|
||||
// 6. Return to the pool (mark success)
|
||||
pooled.finish(Ok(())).await;
|
||||
}
|
||||
Err(status) => {
|
||||
eprintln!(
|
||||
"[worker {}][iter {}] failed to get client: {:?}",
|
||||
worker_id, iteration, status
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Small random pause before next iteration to spread out load
|
||||
let pause = rand::thread_rng().gen_range(0..20);
|
||||
sleep(Duration::from_millis(pause)).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
|
||||
async fn main() {
|
||||
// --------------------------------------
|
||||
// 1. Create factory and shared instrumentation
|
||||
// --------------------------------------
|
||||
let factory = Arc::new(MockConnectionFactory::new());
|
||||
|
||||
// Shared map: connection ID → Arc<AtomicUsize>
|
||||
let usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>> =
|
||||
Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
// Shared set: record each unique connection ID we actually saw
|
||||
let seen_set: Arc<Mutex<HashSet<u64>>> = Arc::new(Mutex::new(HashSet::new()));
|
||||
|
||||
// --------------------------------------
|
||||
// 2. Pool parameters
|
||||
// --------------------------------------
|
||||
let connect_timeout = Duration::from_millis(500);
|
||||
let connect_backoff = Duration::from_millis(100);
|
||||
let max_consumers = 100; // test limit
|
||||
let error_threshold = 2; // mock never fails
|
||||
let max_idle_duration = Duration::from_secs(2);
|
||||
let max_total_connections = 3;
|
||||
let aggregate_metrics = None;
|
||||
|
||||
let pool: Arc<ConnectionPool<MockConnection>> = ConnectionPool::new(
|
||||
factory,
|
||||
connect_timeout,
|
||||
connect_backoff,
|
||||
max_consumers,
|
||||
error_threshold,
|
||||
max_idle_duration,
|
||||
max_total_connections,
|
||||
aggregate_metrics,
|
||||
);
|
||||
|
||||
// --------------------------------------
|
||||
// 3. Spawn worker tasks
|
||||
// --------------------------------------
|
||||
let num_workers = 10000;
|
||||
let mut handles = Vec::with_capacity(num_workers);
|
||||
let start_time = Instant::now();
|
||||
|
||||
for worker_id in 0..num_workers {
|
||||
let pool_clone = Arc::clone(&pool);
|
||||
let usage_clone = Arc::clone(&usage_map);
|
||||
let seen_clone = Arc::clone(&seen_set);
|
||||
let mc = max_consumers;
|
||||
|
||||
let handle = task::spawn(async move {
|
||||
client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await;
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
// 4. Wait for workers to finish
|
||||
// --------------------------------------
|
||||
for handle in handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
let elapsed = Instant::now().duration_since(start_time);
|
||||
println!(
|
||||
"All {} workers completed in {:?}",
|
||||
num_workers, elapsed
|
||||
);
|
||||
|
||||
// --------------------------------------
|
||||
// 5. Print the total number of unique connections seen so far
|
||||
// --------------------------------------
|
||||
let unique_count = {
|
||||
let seen_guard = seen_set.lock().unwrap();
|
||||
seen_guard.len()
|
||||
};
|
||||
println!("Total unique connections used by workers: {}", unique_count);
|
||||
|
||||
// --------------------------------------
|
||||
// 6. Sleep so the background sweeper can run (max_idle_duration = 2 s)
|
||||
// --------------------------------------
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
|
||||
// --------------------------------------
|
||||
// 7. Shutdown the pool
|
||||
// --------------------------------------
|
||||
let shutdown_pool = Arc::clone(&pool);
|
||||
shutdown_pool.shutdown().await;
|
||||
println!("Pool.shutdown() returned.");
|
||||
|
||||
// --------------------------------------
|
||||
// 8. Verify that no background task still holds an Arc clone of `pool`.
|
||||
// If any task is still alive (sweeper/create_connection), strong_count > 1.
|
||||
// --------------------------------------
|
||||
sleep(Duration::from_secs(1)).await; // give tasks time to exit
|
||||
let sc = Arc::strong_count(&pool);
|
||||
assert!(
|
||||
sc == 1,
|
||||
"Pool tasks did not all terminate: Arc::strong_count = {} (expected 1)",
|
||||
sc
|
||||
);
|
||||
println!("Verified: all pool tasks have terminated (strong_count == 1).");
|
||||
|
||||
// --------------------------------------
|
||||
// 9. Verify no MockConnection was leaked:
|
||||
// CREATED must equal DROPPED.
|
||||
// --------------------------------------
|
||||
let created = CREATED.load(Ordering::SeqCst);
|
||||
let dropped = DROPPED.load(Ordering::SeqCst);
|
||||
assert!(
|
||||
created == dropped,
|
||||
"Leaked connections: created={} but dropped={}",
|
||||
created,
|
||||
dropped
|
||||
);
|
||||
println!(
|
||||
"Verified: no connections leaked (created = {}, dropped = {}).",
|
||||
created, dropped
|
||||
);
|
||||
|
||||
// --------------------------------------
|
||||
// 10. Because `client_worker` asserted inside that no connection
|
||||
// ever exceeded `max_consumers`, reaching this point means that check passed.
|
||||
// --------------------------------------
|
||||
println!("All per-connection usage stayed within max_consumers = {}.", max_consumers);
|
||||
|
||||
println!("Load test complete; exiting cleanly.");
|
||||
}
|
||||
160
pageserver/client_grpc/examples/request_tracker_load_test.rs
Normal file
160
pageserver/client_grpc/examples/request_tracker_load_test.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
// examples/request_tracker_load_test.rs
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio;
|
||||
use pageserver_client_grpc::request_tracker::RequestTracker;
|
||||
use pageserver_client_grpc::request_tracker::MockStreamFactory;
|
||||
use pageserver_client_grpc::request_tracker::StreamReturner;
|
||||
use pageserver_client_grpc::client_cache::ConnectionPool;
|
||||
use pageserver_client_grpc::client_cache::PooledItemFactory;
|
||||
use pageserver_client_grpc::ClientCacheOptions;
|
||||
use pageserver_client_grpc::PageserverClientAggregateMetrics;
|
||||
use pageserver_client_grpc::AuthInterceptor;
|
||||
|
||||
use pageserver_client_grpc::client_cache::ChannelFactory;
|
||||
|
||||
use tonic::{transport::{Channel}, Request};
|
||||
|
||||
use rand::prelude::*;
|
||||
|
||||
use pageserver_api::key::Key;
|
||||
|
||||
use utils::lsn::Lsn;
|
||||
use utils::id::TenantTimelineId;
|
||||
|
||||
use futures::stream::FuturesOrdered;
|
||||
use futures::StreamExt;
|
||||
// use chrono
|
||||
use chrono::Utc;
|
||||
|
||||
use pageserver_page_api::{GetPageClass, GetPageResponse};
|
||||
use pageserver_page_api::proto;
|
||||
#[derive(Clone)]
|
||||
struct KeyRange {
|
||||
timeline: TenantTimelineId,
|
||||
timeline_lsn: Lsn,
|
||||
start: i128,
|
||||
end: i128,
|
||||
}
|
||||
|
||||
impl KeyRange {
|
||||
fn len(&self) -> i128 {
|
||||
self.end - self.start
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// 1) configure the client‐pool behavior
|
||||
let client_cache_options = ClientCacheOptions {
|
||||
max_delay_ms: 0,
|
||||
drop_rate: 0.0,
|
||||
hang_rate: 0.0,
|
||||
connect_timeout: Duration::from_secs(10),
|
||||
connect_backoff: Duration::from_millis(200),
|
||||
max_consumers: 64,
|
||||
error_threshold: 10,
|
||||
max_idle_duration: Duration::from_secs(60),
|
||||
max_total_connections: 12,
|
||||
};
|
||||
|
||||
// 2) metrics collector (we assume Default is implemented)
|
||||
let metrics = Arc::new(PageserverClientAggregateMetrics::new());
|
||||
let pool = ConnectionPool::<StreamReturner>::new(
|
||||
Arc::new(MockStreamFactory::new(
|
||||
)),
|
||||
client_cache_options.connect_timeout,
|
||||
client_cache_options.connect_backoff,
|
||||
client_cache_options.max_consumers,
|
||||
client_cache_options.error_threshold,
|
||||
client_cache_options.max_idle_duration,
|
||||
client_cache_options.max_total_connections,
|
||||
Some(Arc::clone(&metrics)),
|
||||
);
|
||||
|
||||
// -----------
|
||||
// There is no mock for the unary connection pool, so for now just
|
||||
// don't use this pool
|
||||
//
|
||||
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
|
||||
"".to_string(),
|
||||
client_cache_options.max_delay_ms,
|
||||
client_cache_options.drop_rate,
|
||||
client_cache_options.hang_rate,
|
||||
));
|
||||
let unary_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
|
||||
Arc::clone(&channel_fact),
|
||||
client_cache_options.connect_timeout,
|
||||
client_cache_options.connect_backoff,
|
||||
client_cache_options.max_consumers,
|
||||
client_cache_options.error_threshold,
|
||||
client_cache_options.max_idle_duration,
|
||||
client_cache_options.max_total_connections,
|
||||
Some(Arc::clone(&metrics)),
|
||||
);
|
||||
|
||||
// -----------
|
||||
// Dummy auth interceptor. This is not used in this test.
|
||||
let auth_interceptor = AuthInterceptor::new("dummy_tenant_id",
|
||||
"dummy_timeline_id",
|
||||
None);
|
||||
let mut tracker = RequestTracker::new(
|
||||
pool,
|
||||
unary_pool,
|
||||
auth_interceptor,
|
||||
);
|
||||
|
||||
// 4) fire off 10 000 requests in parallel
|
||||
let mut handles = FuturesOrdered::new();
|
||||
for i in 0..500000 {
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let r = 0..=1000000i128;
|
||||
let key: i128 = rng.gen_range(r.clone());
|
||||
let key = Key::from_i128(key);
|
||||
let (rel_tag, block_no) = key
|
||||
.to_rel_block()
|
||||
.expect("we filter non-rel-block keys out above");
|
||||
|
||||
let req2 = proto::GetPageRequest {
|
||||
request_id: 0,
|
||||
request_class: proto::GetPageClass::Normal as i32,
|
||||
read_lsn: Some(proto::ReadLsn {
|
||||
request_lsn: if rng.gen_bool(0.5) {
|
||||
u64::from(Lsn::MAX)
|
||||
} else {
|
||||
10000
|
||||
},
|
||||
not_modified_since_lsn: 10000,
|
||||
}),
|
||||
rel: Some(rel_tag.into()),
|
||||
block_number: vec![block_no],
|
||||
};
|
||||
let req_model = pageserver_page_api::GetPageRequest::try_from(req2.clone());
|
||||
|
||||
// RequestTracker is Clone, so we can share it
|
||||
let mut tr = tracker.clone();
|
||||
let fut = async move {
|
||||
let resp = tr.send_getpage_request(req_model.unwrap()).await.unwrap();
|
||||
// sanity‐check: the mock echo returns the same request_id
|
||||
assert!(resp.request_id > 0);
|
||||
};
|
||||
handles.push_back(fut);
|
||||
|
||||
// empty future
|
||||
let fut = async move {};
|
||||
fut.await;
|
||||
}
|
||||
|
||||
// print timestamp
|
||||
println!("Starting 5000000 requests at: {}", chrono::Utc::now());
|
||||
// 5) wait for them all
|
||||
for i in 0..500000 {
|
||||
handles.next().await.expect("Failed to get next handle");
|
||||
}
|
||||
|
||||
// print timestamp
|
||||
println!("Finished 5000000 requests at: {}", chrono::Utc::now());
|
||||
|
||||
println!("✅ All 100000 requests completed successfully");
|
||||
}
|
||||
@@ -31,6 +31,7 @@ use hyper_util::rt::TokioIo;
|
||||
use tower::service_fn;
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use async_trait::async_trait;
|
||||
|
||||
//
|
||||
// The "TokioTcp" is flakey TCP network for testing purposes, in order
|
||||
@@ -164,32 +165,132 @@ impl AsyncWrite for TokioTcp {
|
||||
}
|
||||
}
|
||||
|
||||
/// A pooled gRPC client with capacity tracking and error handling.
|
||||
pub struct ConnectionPool {
|
||||
inner: Mutex<Inner>,
|
||||
#[async_trait]
|
||||
pub trait PooledItemFactory<T>: Send + Sync + 'static {
|
||||
/// Create a new pooled item.
|
||||
async fn create(&self, connect_timeout: Duration) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>;
|
||||
}
|
||||
|
||||
// Config options that apply to each connection
|
||||
pub struct ChannelFactory {
|
||||
endpoint: String,
|
||||
max_consumers: usize,
|
||||
error_threshold: usize,
|
||||
connect_timeout: Duration,
|
||||
connect_backoff: Duration,
|
||||
|
||||
// Parameters for testing
|
||||
max_delay_ms: u64,
|
||||
drop_rate: f64,
|
||||
hang_rate: f64,
|
||||
}
|
||||
|
||||
// The maximum duration a connection can be idle before being removed
|
||||
|
||||
impl ChannelFactory {
|
||||
pub fn new(
|
||||
endpoint: String,
|
||||
max_delay_ms: u64,
|
||||
drop_rate: f64,
|
||||
hang_rate: f64,
|
||||
) -> Self {
|
||||
ChannelFactory {
|
||||
endpoint,
|
||||
max_delay_ms,
|
||||
drop_rate,
|
||||
hang_rate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PooledItemFactory<Channel> for ChannelFactory {
|
||||
async fn create(&self, connect_timeout: Duration) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> {
|
||||
let max_delay_ms = self.max_delay_ms;
|
||||
let drop_rate = self.drop_rate;
|
||||
let hang_rate = self.hang_rate;
|
||||
|
||||
// This is a custom connector that inserts delays and errors, for
|
||||
// testing purposes. It would normally be disabled by the config.
|
||||
let connector = service_fn(move |uri: Uri| {
|
||||
let drop_rate = drop_rate;
|
||||
let hang_rate = hang_rate;
|
||||
async move {
|
||||
let mut rng = StdRng::from_entropy();
|
||||
// Simulate an indefinite hang
|
||||
if hang_rate > 0.0 && rng.gen_bool(hang_rate) {
|
||||
// never completes, to test timeout
|
||||
return future::pending::<Result<TokioIo<TokioTcp>, std::io::Error>>().await;
|
||||
}
|
||||
|
||||
// Random drop (connect error)
|
||||
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"simulated connect drop",
|
||||
));
|
||||
}
|
||||
|
||||
// Otherwise perform real TCP connect
|
||||
let addr = match (uri.host(), uri.port()) {
|
||||
// host + explicit port
|
||||
(Some(host), Some(port)) => format!("{}:{}", host, port.as_str()),
|
||||
// host only (no port)
|
||||
(Some(host), None) => host.to_string(),
|
||||
// neither? error out
|
||||
_ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")),
|
||||
};
|
||||
|
||||
let tcp = TcpStream::connect(addr).await?;
|
||||
let tcpwrapper = TokioTcp::new(tcp, max_delay_ms);
|
||||
Ok(TokioIo::new(tcpwrapper))
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
let attempt = tokio::time::timeout(
|
||||
connect_timeout,
|
||||
Endpoint::from_shared(self.endpoint.clone())
|
||||
.expect("invalid endpoint")
|
||||
.timeout(connect_timeout)
|
||||
.connect_with_connector(connector),
|
||||
)
|
||||
.await;
|
||||
match attempt {
|
||||
Ok(Ok(channel)) => {
|
||||
// Connection succeeded
|
||||
Ok(Ok(channel))
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
Ok(Err(tonic::Status::new(
|
||||
tonic::Code::Unavailable,
|
||||
format!("Failed to connect: {}", e),
|
||||
)))
|
||||
}
|
||||
Err(e) => {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// A pooled gRPC client with capacity tracking and error handling.
|
||||
pub struct ConnectionPool<T> {
|
||||
inner: Mutex<Inner<T>>,
|
||||
|
||||
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
|
||||
|
||||
connect_timeout: Duration,
|
||||
connect_backoff: Duration,
|
||||
/// The maximum number of consumers that can use a single connection.
|
||||
max_consumers: usize,
|
||||
/// The number of consecutive errors before a connection is removed from the pool.
|
||||
error_threshold: usize,
|
||||
/// The maximum duration a connection can be idle before being removed.
|
||||
max_idle_duration: Duration,
|
||||
max_total_connections: usize,
|
||||
|
||||
channel_semaphore: Arc<Semaphore>,
|
||||
|
||||
shutdown_token: CancellationToken,
|
||||
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
entries: HashMap<uuid::Uuid, ConnectionEntry>,
|
||||
struct Inner<T> {
|
||||
entries: HashMap<uuid::Uuid, ConnectionEntry<T>>,
|
||||
pq: PriorityQueue<uuid::Uuid, usize>,
|
||||
// This is updated when a connection is dropped, or we fail
|
||||
// to create a new connection.
|
||||
@@ -197,54 +298,50 @@ struct Inner {
|
||||
waiters: usize,
|
||||
in_progress: usize,
|
||||
}
|
||||
|
||||
struct ConnectionEntry {
|
||||
channel: Channel,
|
||||
struct ConnectionEntry<T> {
|
||||
channel: T,
|
||||
active_consumers: usize,
|
||||
consecutive_errors: usize,
|
||||
last_used: Instant,
|
||||
}
|
||||
|
||||
/// A client borrowed from the pool.
|
||||
pub struct PooledClient {
|
||||
pub channel: Channel,
|
||||
pool: Arc<ConnectionPool>,
|
||||
pub struct PooledClient<T> {
|
||||
pub channel: T,
|
||||
pool: Arc<ConnectionPool<T>>,
|
||||
is_ok: bool,
|
||||
id: uuid::Uuid,
|
||||
permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
impl ConnectionPool {
|
||||
impl<T: Clone + Send + 'static> ConnectionPool<T> {
|
||||
pub fn new(
|
||||
endpoint: &String,
|
||||
max_consumers: usize,
|
||||
error_threshold: usize,
|
||||
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
|
||||
connect_timeout: Duration,
|
||||
connect_backoff: Duration,
|
||||
max_consumers: usize,
|
||||
error_threshold: usize,
|
||||
max_idle_duration: Duration,
|
||||
max_delay_ms: u64,
|
||||
drop_rate: f64,
|
||||
hang_rate: f64,
|
||||
max_total_connections: usize,
|
||||
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
|
||||
) -> Arc<Self> {
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let pool = Arc::new(Self {
|
||||
inner: Mutex::new(Inner {
|
||||
inner: Mutex::new(Inner::<T> {
|
||||
entries: HashMap::new(),
|
||||
pq: PriorityQueue::new(),
|
||||
last_connect_failure: None,
|
||||
waiters: 0,
|
||||
in_progress: 0,
|
||||
}),
|
||||
fact: Arc::clone(&fact),
|
||||
connect_timeout,
|
||||
connect_backoff,
|
||||
max_consumers,
|
||||
error_threshold,
|
||||
max_idle_duration,
|
||||
max_total_connections,
|
||||
channel_semaphore: Arc::new(Semaphore::new(0)),
|
||||
endpoint: endpoint.clone(),
|
||||
max_consumers: max_consumers,
|
||||
error_threshold: error_threshold,
|
||||
connect_timeout: connect_timeout,
|
||||
connect_backoff: connect_backoff,
|
||||
max_idle_duration: max_idle_duration,
|
||||
max_delay_ms: max_delay_ms,
|
||||
drop_rate: drop_rate,
|
||||
hang_rate: hang_rate,
|
||||
shutdown_token: shutdown_token.clone(),
|
||||
aggregate_metrics: aggregate_metrics.clone(),
|
||||
});
|
||||
@@ -325,7 +422,7 @@ impl ConnectionPool {
|
||||
async fn get_conn_with_permit(
|
||||
self: Arc<Self>,
|
||||
permit: OwnedSemaphorePermit,
|
||||
) -> Option<PooledClient> {
|
||||
) -> Option<PooledClient<T>> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
// Pop the highest-active-consumers connection. There are no connections
|
||||
@@ -340,9 +437,10 @@ impl ConnectionPool {
|
||||
entry.active_consumers += 1;
|
||||
entry.last_used = Instant::now();
|
||||
|
||||
let client = PooledClient {
|
||||
let client = PooledClient::<T> {
|
||||
channel: entry.channel.clone(),
|
||||
pool: Arc::clone(&self),
|
||||
is_ok: true,
|
||||
id,
|
||||
permit: permit,
|
||||
};
|
||||
@@ -365,7 +463,7 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_client(self: Arc<Self>) -> Result<PooledClient, tonic::Status> {
|
||||
pub async fn get_client(self: Arc<Self>) -> Result<PooledClient<T>, tonic::Status> {
|
||||
// The pool is shutting down. Don't accept new connections.
|
||||
if self.shutdown_token.is_cancelled() {
|
||||
return Err(tonic::Status::unavailable("Pool is shutting down"));
|
||||
@@ -412,12 +510,16 @@ impl ConnectionPool {
|
||||
//
|
||||
let mut inner = self_clone.inner.lock().await;
|
||||
inner.waiters += 1;
|
||||
if inner.waiters >= (inner.in_progress * self_clone.max_consumers) {
|
||||
let self_clone_spawn = Arc::clone(&self_clone);
|
||||
tokio::task::spawn(async move {
|
||||
self_clone_spawn.create_connection().await;
|
||||
});
|
||||
inner.in_progress += 1;
|
||||
if inner.waiters > (inner.in_progress * self_clone.max_consumers) {
|
||||
if (inner.entries.len() + inner.in_progress) < self_clone.max_total_connections {
|
||||
|
||||
let self_clone_spawn = Arc::clone(&self_clone);
|
||||
tokio::task::spawn(async move {
|
||||
self_clone_spawn.create_connection().await;
|
||||
});
|
||||
inner.in_progress += 1;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
// Wait for a connection to become available, either because it
|
||||
@@ -446,46 +548,6 @@ impl ConnectionPool {
|
||||
}
|
||||
|
||||
async fn create_connection(&self) -> () {
|
||||
let max_delay_ms = self.max_delay_ms;
|
||||
let drop_rate = self.drop_rate;
|
||||
let hang_rate = self.hang_rate;
|
||||
|
||||
// This is a custom connector that inserts delays and errors, for
|
||||
// testing purposes. It would normally be disabled by the config.
|
||||
let connector = service_fn(move |uri: Uri| {
|
||||
let drop_rate = drop_rate;
|
||||
let hang_rate = hang_rate;
|
||||
async move {
|
||||
let mut rng = StdRng::from_entropy();
|
||||
// Simulate an indefinite hang
|
||||
if hang_rate > 0.0 && rng.gen_bool(hang_rate) {
|
||||
// never completes, to test timeout
|
||||
return future::pending::<Result<TokioIo<TokioTcp>, std::io::Error>>().await;
|
||||
}
|
||||
|
||||
// Random drop (connect error)
|
||||
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"simulated connect drop",
|
||||
));
|
||||
}
|
||||
|
||||
// Otherwise perform real TCP connect
|
||||
let addr = match (uri.host(), uri.port()) {
|
||||
// host + explicit port
|
||||
(Some(host), Some(port)) => format!("{}:{}", host, port.as_str()),
|
||||
// host only (no port)
|
||||
(Some(host), None) => host.to_string(),
|
||||
// neither? error out
|
||||
_ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")),
|
||||
};
|
||||
|
||||
let tcp = TcpStream::connect(addr).await?;
|
||||
let tcpwrapper = TokioTcp::new(tcp, max_delay_ms);
|
||||
Ok(TokioIo::new(tcpwrapper))
|
||||
}
|
||||
});
|
||||
|
||||
// Generate a random backoff to add some jitter so that connections
|
||||
// don't all retry at the same time.
|
||||
@@ -533,14 +595,9 @@ impl ConnectionPool {
|
||||
None => {}
|
||||
}
|
||||
|
||||
let attempt = tokio::time::timeout(
|
||||
self.connect_timeout,
|
||||
Endpoint::from_shared(self.endpoint.clone())
|
||||
.expect("invalid endpoint")
|
||||
.timeout(self.connect_timeout)
|
||||
.connect_with_connector(connector),
|
||||
)
|
||||
.await;
|
||||
let attempt = self.fact
|
||||
.create(self.connect_timeout)
|
||||
.await;
|
||||
|
||||
match attempt {
|
||||
// Connection succeeded
|
||||
@@ -559,7 +616,7 @@ impl ConnectionPool {
|
||||
let id = uuid::Uuid::new_v4();
|
||||
inner.entries.insert(
|
||||
id,
|
||||
ConnectionEntry {
|
||||
ConnectionEntry::<T> {
|
||||
channel: channel.clone(),
|
||||
active_consumers: 0,
|
||||
consecutive_errors: 0,
|
||||
@@ -641,6 +698,11 @@ impl ConnectionPool {
|
||||
inner.pq.remove(&id);
|
||||
}
|
||||
|
||||
// remove from entries
|
||||
// check if entry is in inner
|
||||
if inner.entries.contains_key(&id) {
|
||||
inner.entries.remove(&id);
|
||||
}
|
||||
inner.last_connect_failure = Some(Instant::now());
|
||||
|
||||
// The connection has been removed, it's permits will be
|
||||
@@ -661,18 +723,19 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
}
|
||||
// The semaphore permit is released when the pooled client is dropped.
|
||||
}
|
||||
}
|
||||
|
||||
impl PooledClient {
|
||||
pub fn channel(&self) -> Channel {
|
||||
impl<T: Clone + Send + 'static> PooledClient<T> {
|
||||
pub fn channel(&self) -> T {
|
||||
return self.channel.clone();
|
||||
}
|
||||
|
||||
pub async fn finish(self, result: Result<(), tonic::Status>) {
|
||||
self.pool
|
||||
.return_client(self.id, result.is_ok(), self.permit)
|
||||
.await;
|
||||
pub async fn finish(mut self, result: Result<(), tonic::Status>) {
|
||||
self.is_ok = result.is_ok();
|
||||
self.pool.return_client(
|
||||
self.id,
|
||||
self.is_ok,
|
||||
self.permit,
|
||||
).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,9 +20,16 @@ use pageserver_page_api::proto::PageServiceClient;
|
||||
use utils::shard::ShardIndex;
|
||||
|
||||
use std::fmt::Debug;
|
||||
mod client_cache;
|
||||
pub mod client_cache;
|
||||
pub mod request_tracker;
|
||||
use tonic::transport::Channel;
|
||||
|
||||
use metrics::{IntCounterVec, core::Collector};
|
||||
use crate::client_cache::{PooledItemFactory};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use async_trait::async_trait;
|
||||
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum PageserverClientError {
|
||||
@@ -77,6 +84,7 @@ impl PageserverClientAggregateMetrics {
|
||||
metrics
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PageserverClient {
|
||||
_tenant_id: String,
|
||||
_timeline_id: String,
|
||||
@@ -85,7 +93,7 @@ pub struct PageserverClient {
|
||||
|
||||
shard_map: HashMap<ShardIndex, String>,
|
||||
|
||||
channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool>>>,
|
||||
channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool<Channel>>>>,
|
||||
|
||||
auth_interceptor: AuthInterceptor,
|
||||
|
||||
@@ -93,13 +101,14 @@ pub struct PageserverClient {
|
||||
|
||||
aggregate_metrics: Option<Arc<PageserverClientAggregateMetrics>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClientCacheOptions {
|
||||
pub max_consumers: usize,
|
||||
pub error_threshold: usize,
|
||||
pub connect_timeout: Duration,
|
||||
pub connect_backoff: Duration,
|
||||
pub max_idle_duration: Duration,
|
||||
pub max_total_connections: usize,
|
||||
pub max_delay_ms: u64,
|
||||
pub drop_rate: f64,
|
||||
pub hang_rate: f64,
|
||||
@@ -119,6 +128,7 @@ impl PageserverClient {
|
||||
connect_timeout: Duration::from_secs(5),
|
||||
connect_backoff: Duration::from_secs(1),
|
||||
max_idle_duration: Duration::from_secs(60),
|
||||
max_total_connections: 100000,
|
||||
max_delay_ms: 0,
|
||||
drop_rate: 0.0,
|
||||
hang_rate: 0.0,
|
||||
@@ -349,13 +359,13 @@ impl PageserverClient {
|
||||
///
|
||||
/// Get a client from the pool for this shard, also creating the pool if it doesn't exist.
|
||||
///
|
||||
async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient {
|
||||
let reused_pool: Option<Arc<client_cache::ConnectionPool>> = {
|
||||
async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient<Channel> {
|
||||
let reused_pool: Option<Arc<client_cache::ConnectionPool<Channel>>> = {
|
||||
let channels = self.channels.read().unwrap();
|
||||
channels.get(&shard).cloned()
|
||||
};
|
||||
|
||||
let usable_pool: Arc<client_cache::ConnectionPool>;
|
||||
let usable_pool: Arc<client_cache::ConnectionPool<Channel>>;
|
||||
match reused_pool {
|
||||
Some(pool) => {
|
||||
let pooled_client = pool.get_client().await.unwrap();
|
||||
@@ -365,17 +375,21 @@ impl PageserverClient {
|
||||
// Create a new pool using client_cache_options
|
||||
// declare new_pool
|
||||
|
||||
let new_pool: Arc<client_cache::ConnectionPool>;
|
||||
new_pool = client_cache::ConnectionPool::new(
|
||||
self.shard_map.get(&shard).unwrap(),
|
||||
self.client_cache_options.max_consumers,
|
||||
self.client_cache_options.error_threshold,
|
||||
self.client_cache_options.connect_timeout,
|
||||
self.client_cache_options.connect_backoff,
|
||||
self.client_cache_options.max_idle_duration,
|
||||
let new_pool: Arc<client_cache::ConnectionPool<Channel>>;
|
||||
let channel_fact = Arc::new(client_cache::ChannelFactory::new(
|
||||
self.shard_map.get(&shard).unwrap().clone(),
|
||||
self.client_cache_options.max_delay_ms,
|
||||
self.client_cache_options.drop_rate,
|
||||
self.client_cache_options.hang_rate,
|
||||
));
|
||||
new_pool = client_cache::ConnectionPool::new(
|
||||
channel_fact,
|
||||
self.client_cache_options.connect_timeout,
|
||||
self.client_cache_options.connect_backoff,
|
||||
self.client_cache_options.max_consumers,
|
||||
self.client_cache_options.error_threshold,
|
||||
self.client_cache_options.max_idle_duration,
|
||||
self.client_cache_options.max_total_connections,
|
||||
self.aggregate_metrics.clone(),
|
||||
);
|
||||
let mut write_pool = self.channels.write().unwrap();
|
||||
@@ -391,7 +405,7 @@ impl PageserverClient {
|
||||
|
||||
/// Inject tenant_id, timeline_id and authentication token to all pageserver requests.
|
||||
#[derive(Clone)]
|
||||
struct AuthInterceptor {
|
||||
pub struct AuthInterceptor {
|
||||
tenant_id: AsciiMetadataValue,
|
||||
shard_id: Option<AsciiMetadataValue>,
|
||||
timeline_id: AsciiMetadataValue,
|
||||
@@ -400,7 +414,7 @@ struct AuthInterceptor {
|
||||
}
|
||||
|
||||
impl AuthInterceptor {
|
||||
fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self {
|
||||
pub fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self {
|
||||
Self {
|
||||
tenant_id: tenant_id.parse().expect("could not parse tenant id"),
|
||||
shard_id: None,
|
||||
|
||||
590
pageserver/client_grpc/src/request_tracker.rs
Normal file
590
pageserver/client_grpc/src/request_tracker.rs
Normal file
@@ -0,0 +1,590 @@
|
||||
|
||||
//
|
||||
// API Visible to the spawner, just a function call that is async
|
||||
//
|
||||
use std::sync::Arc;
|
||||
use crate::client_cache;
|
||||
use pageserver_page_api::GetPageRequest;
|
||||
use pageserver_page_api::GetPageResponse;
|
||||
use pageserver_page_api::*;
|
||||
use pageserver_page_api::proto;
|
||||
use crate::client_cache::ConnectionPool;
|
||||
use crate::client_cache::ChannelFactory;
|
||||
use crate::AuthInterceptor;
|
||||
use tonic::{transport::{Channel}, Request};
|
||||
use crate::ClientCacheOptions;
|
||||
use crate::PageserverClientAggregateMetrics;
|
||||
use tokio::sync::Mutex;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use utils::shard::ShardIndex;
|
||||
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use pageserver_page_api::proto::PageServiceClient;
|
||||
|
||||
use tonic::{
|
||||
Status,
|
||||
Code,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::time::Duration;
|
||||
|
||||
use client_cache::PooledItemFactory;
|
||||
//use tracing::info;
|
||||
//
|
||||
// A mock stream pool that just returns a sending channel, and whenever a GetPageRequest
|
||||
// comes in on that channel, it randomly sleeps before sending a GetPageResponse
|
||||
//
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StreamReturner {
|
||||
sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
|
||||
sender_hashmap: Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>>>>,
|
||||
}
|
||||
pub struct MockStreamFactory {
|
||||
}
|
||||
|
||||
impl MockStreamFactory {
|
||||
pub fn new() -> Self {
|
||||
MockStreamFactory {
|
||||
}
|
||||
}
|
||||
}
|
||||
#[async_trait]
|
||||
impl PooledItemFactory<StreamReturner> for MockStreamFactory {
|
||||
async fn create(&self, _connect_timeout: Duration) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> {
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
|
||||
// Create a StreamReturner that will send requests to the receiver channel
|
||||
let stream_returner = StreamReturner {
|
||||
sender: sender.clone(),
|
||||
sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
};
|
||||
|
||||
let map : Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>>>>
|
||||
= Arc::clone(&stream_returner.sender_hashmap);
|
||||
tokio::spawn(async move {
|
||||
while let Some(request) = receiver.recv().await {
|
||||
|
||||
// Break out of the loop with 1% chance
|
||||
if rand::random::<f32>() < 0.001 {
|
||||
break;
|
||||
}
|
||||
// Generate a random number between 0 and 100
|
||||
// Simulate some processing time
|
||||
let mapclone = Arc::clone(&map);
|
||||
tokio::spawn(async move {
|
||||
let sleep_ms = rand::random::<u64>() % 100;
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
|
||||
let response = proto::GetPageResponse {
|
||||
request_id: request.request_id,
|
||||
..Default::default()
|
||||
};
|
||||
// look up stream in hash map
|
||||
let mut hashmap = mapclone.lock().await;
|
||||
if let Some(sender) = hashmap.get(&request.request_id) {
|
||||
// Send the response to the original request sender
|
||||
if let Err(e) = sender.send(Ok(response.clone())).await {
|
||||
eprintln!("Failed to send response: {}", e);
|
||||
}
|
||||
hashmap.remove(&request.request_id);
|
||||
} else {
|
||||
eprintln!("No sender found for request ID: {}", request.request_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
// Close every sender stream in the hashmap
|
||||
let hashmap = map.lock().await;
|
||||
for sender in hashmap.values() {
|
||||
let error = Status::new(Code::Unknown, "Stream closed");
|
||||
if let Err(e) = sender.send(Err(error)).await {
|
||||
eprintln!("Failed to send close response: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Ok(stream_returner))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct StreamFactory {
|
||||
connection_pool: Arc<client_cache::ConnectionPool<Channel>>,
|
||||
auth_interceptor: AuthInterceptor,
|
||||
shard: ShardIndex,
|
||||
}
|
||||
|
||||
impl StreamFactory {
|
||||
pub fn new(
|
||||
connection_pool: Arc<ConnectionPool<Channel>>,
|
||||
auth_interceptor: AuthInterceptor,
|
||||
shard: ShardIndex,
|
||||
) -> Self {
|
||||
StreamFactory {
|
||||
connection_pool,
|
||||
auth_interceptor,
|
||||
shard,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PooledItemFactory<StreamReturner> for StreamFactory {
|
||||
async fn create(&self, _connect_timeout: Duration) ->
|
||||
Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed>
|
||||
{
|
||||
let pool_clone : Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool);
|
||||
let pooled_client = pool_clone.get_client().await;
|
||||
let channel = pooled_client.unwrap().channel();
|
||||
let mut client =
|
||||
PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
|
||||
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
|
||||
let outbound = ReceiverStream::new(receiver);
|
||||
|
||||
let client_resp = client
|
||||
.get_pages(Request::new(outbound))
|
||||
.await;
|
||||
|
||||
match client_resp {
|
||||
Err(status) => {
|
||||
// TODO: Convert this error correctly
|
||||
Ok(Err(tonic::Status::new(
|
||||
status.code(),
|
||||
format!("Failed to connect to pageserver: {}", status.message()),
|
||||
)))
|
||||
}
|
||||
Ok(resp) => {
|
||||
let stream_returner = StreamReturner {
|
||||
sender: sender.clone(),
|
||||
sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
};
|
||||
let map : Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>>>>
|
||||
= Arc::clone(&stream_returner.sender_hashmap);
|
||||
|
||||
tokio::spawn(async move {
|
||||
|
||||
let map_clone = Arc::clone(&map);
|
||||
let mut inner = resp.into_inner();
|
||||
loop {
|
||||
|
||||
let resp = inner.message().await;
|
||||
if !resp.is_ok() {
|
||||
break; // Exit the loop if no more messages
|
||||
}
|
||||
let response = resp.unwrap().unwrap();
|
||||
|
||||
// look up stream in hash map
|
||||
let mut hashmap = map_clone.lock().await;
|
||||
if let Some(sender) = hashmap.get(&response.request_id) {
|
||||
// Send the response to the original request sender
|
||||
if let Err(e) = sender.send(Ok(response.clone())).await {
|
||||
eprintln!("Failed to send response: {}", e);
|
||||
}
|
||||
hashmap.remove(&response.request_id);
|
||||
} else {
|
||||
eprintln!("No sender found for request ID: {}", response.request_id);
|
||||
}
|
||||
}
|
||||
// Close every sender stream in the hashmap
|
||||
let hashmap = map_clone.lock().await;
|
||||
for sender in hashmap.values() {
|
||||
let error = Status::new(Code::Unknown, "Stream closed");
|
||||
if let Err(e) = sender.send(Err(error)).await {
|
||||
eprintln!("Failed to send close response: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Ok(stream_returner))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RequestTracker {
|
||||
cur_id: Arc<AtomicU64>,
|
||||
stream_pool: Arc<ConnectionPool<StreamReturner>>,
|
||||
unary_pool: Arc<ConnectionPool<Channel>>,
|
||||
auth_interceptor: AuthInterceptor,
|
||||
shard: ShardIndex,
|
||||
}
|
||||
|
||||
impl RequestTracker {
|
||||
pub fn new(stream_pool: Arc<ConnectionPool<StreamReturner>>,
|
||||
unary_pool: Arc<ConnectionPool<Channel>>,
|
||||
auth_interceptor: AuthInterceptor,
|
||||
shard: ShardIndex,
|
||||
) -> Self {
|
||||
let cur_id = Arc::new(AtomicU64::new(0));
|
||||
|
||||
RequestTracker {
|
||||
cur_id: cur_id.clone(),
|
||||
stream_pool: stream_pool,
|
||||
unary_pool: unary_pool,
|
||||
auth_interceptor: auth_interceptor,
|
||||
shard: shard.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_process_check_rel_exists_request(
|
||||
&self,
|
||||
req: CheckRelExistsRequest,
|
||||
) -> Result<bool, tonic::Status> {
|
||||
loop {
|
||||
let unary_pool = Arc::clone(&self.unary_pool);
|
||||
let pooled_client = unary_pool.get_client().await.unwrap();
|
||||
let channel = pooled_client.channel();
|
||||
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
|
||||
let request = proto::CheckRelExistsRequest::from(req.clone());
|
||||
let response = ps_client.check_rel_exists(tonic::Request::new(request)).await;
|
||||
|
||||
match response {
|
||||
Err(status) => {
|
||||
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
|
||||
continue;
|
||||
}
|
||||
Ok(resp) => {
|
||||
pooled_client.finish(Ok(())).await; // Pass success to finish
|
||||
return Ok(resp.get_ref().exists);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_process_get_rel_size_request(
|
||||
&self,
|
||||
req: GetRelSizeRequest,
|
||||
) -> Result<u32, tonic::Status> {
|
||||
loop {
|
||||
// Current sharding model assumes that all metadata is present only at shard 0.
|
||||
let unary_pool = Arc::clone(&self.unary_pool);
|
||||
let pooled_client = unary_pool.get_client().await.unwrap();
|
||||
let channel = pooled_client.channel();
|
||||
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
|
||||
|
||||
let request = proto::GetRelSizeRequest::from(req.clone());
|
||||
let response = ps_client.get_rel_size(tonic::Request::new(request)).await;
|
||||
|
||||
match response {
|
||||
Err(status) => {
|
||||
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
|
||||
continue;
|
||||
}
|
||||
Ok(resp) => {
|
||||
pooled_client.finish(Ok(())).await; // Pass success to finish
|
||||
return Ok(resp.get_ref().num_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_process_get_dbsize_request(
|
||||
&self,
|
||||
req: GetDbSizeRequest,
|
||||
) -> Result<u64, tonic::Status> {
|
||||
loop {
|
||||
// Current sharding model assumes that all metadata is present only at shard 0.
|
||||
let unary_pool = Arc::clone(&self.unary_pool);
|
||||
let pooled_client = unary_pool.get_client().await.unwrap();let channel = pooled_client.channel();
|
||||
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
|
||||
|
||||
let request = proto::GetDbSizeRequest::from(req.clone());
|
||||
let response = ps_client.get_db_size(tonic::Request::new(request)).await;
|
||||
|
||||
match response {
|
||||
Err(status) => {
|
||||
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
|
||||
continue;
|
||||
}
|
||||
Ok(resp) => {
|
||||
pooled_client.finish(Ok(())).await; // Pass success to finish
|
||||
return Ok(resp.get_ref().num_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_getpage_request(
|
||||
&mut self,
|
||||
req: GetPageRequest,
|
||||
) -> Result<GetPageResponse, tonic::Status> {
|
||||
loop {
|
||||
let mut request = req.clone();
|
||||
// Increment cur_id
|
||||
//let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let request_id = request.request_id;
|
||||
let response_sender: tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>;
|
||||
let mut response_receiver: tokio::sync::mpsc::Receiver<Result<proto::GetPageResponse, Status>>;
|
||||
|
||||
(response_sender, response_receiver) = tokio::sync::mpsc::channel(1);
|
||||
//request.request_id = request_id;
|
||||
|
||||
// Get a stream from the stream pool
|
||||
let pool_clone = Arc::clone(&self.stream_pool);
|
||||
let sender_stream_pool = pool_clone.get_client().await;
|
||||
let stream_returner = match sender_stream_pool {
|
||||
Ok(stream_ret) => stream_ret,
|
||||
Err(_e) => {
|
||||
// retry
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let returner = stream_returner.channel();
|
||||
let map = returner.sender_hashmap.clone();
|
||||
// Insert the response sender into the hashmap
|
||||
{
|
||||
let mut map_inner = map.lock().await;
|
||||
map_inner.insert(request_id, response_sender);
|
||||
}
|
||||
let sent = returner.sender.send(proto::GetPageRequest::from(request))
|
||||
.await;
|
||||
|
||||
if let Err(_e) = sent {
|
||||
// Remove the request from the map if sending failed
|
||||
{
|
||||
let mut map_inner = map.lock().await;
|
||||
// remove from hashmap
|
||||
map_inner.remove(&request_id);
|
||||
}
|
||||
stream_returner.finish(Err(Status::new(Code::Unknown,
|
||||
"Failed to send request"))).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let response: Option<Result<proto::GetPageResponse, Status>>;
|
||||
response = response_receiver.recv().await;
|
||||
match response {
|
||||
Some (resp) => {
|
||||
match resp {
|
||||
Err(_status) => {
|
||||
// Handle the case where the response was not received
|
||||
stream_returner.finish(Err(Status::new(Code::Unknown,
|
||||
"Failed to receive response"))).await;
|
||||
continue;
|
||||
},
|
||||
Ok(resp) => {
|
||||
stream_returner.finish(Result::Ok(())).await;
|
||||
return Ok(resp.clone().into());
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Handle the case where the response channel was closed
|
||||
stream_returner.finish(Err(Status::new(Code::Unknown,
|
||||
"Response channel closed"))).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ShardedRequestTrackerInner {
|
||||
// Hashmap of shard index to RequestTracker
|
||||
trackers: std::collections::HashMap<ShardIndex, RequestTracker>,
|
||||
}
|
||||
pub struct ShardedRequestTracker {
|
||||
inner: Arc<Mutex<ShardedRequestTrackerInner>>,
|
||||
tcp_client_cache_options: ClientCacheOptions,
|
||||
stream_client_cache_options: ClientCacheOptions,
|
||||
}
|
||||
|
||||
//
|
||||
// TODO: Functions in the ShardedRequestTracker should be able to timeout and
|
||||
// cancel a reqeust. The request should return an error if it is cancelled.
|
||||
//
|
||||
impl ShardedRequestTracker {
|
||||
pub fn new() -> Self {
|
||||
//
|
||||
// Default configuration for the client. These could be added to a config file
|
||||
//
|
||||
let tcp_client_cache_options = ClientCacheOptions {
|
||||
max_delay_ms: 0,
|
||||
drop_rate: 0.0,
|
||||
hang_rate: 0.0,
|
||||
connect_timeout: Duration::from_secs(1),
|
||||
connect_backoff: Duration::from_millis(100),
|
||||
max_consumers: 8, // Streams per connection
|
||||
error_threshold: 10,
|
||||
max_idle_duration: Duration::from_secs(5),
|
||||
max_total_connections: 8,
|
||||
};
|
||||
let stream_client_cache_options = ClientCacheOptions {
|
||||
max_delay_ms: 0,
|
||||
drop_rate: 0.0,
|
||||
hang_rate: 0.0,
|
||||
connect_timeout: Duration::from_secs(1),
|
||||
connect_backoff: Duration::from_millis(100),
|
||||
max_consumers: 64, // Requests per stream
|
||||
error_threshold: 10,
|
||||
max_idle_duration: Duration::from_secs(5),
|
||||
max_total_connections: 64, // Total allowable number of streams
|
||||
};
|
||||
ShardedRequestTracker {
|
||||
inner: Arc::new(Mutex::new(ShardedRequestTrackerInner {
|
||||
trackers: std::collections::HashMap::new(),
|
||||
})),
|
||||
tcp_client_cache_options,
|
||||
stream_client_cache_options,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_shard_map(&self,
|
||||
shard_urls: std::collections::HashMap<ShardIndex, String>,
|
||||
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
|
||||
tenant_id: String, timeline_id: String, auth_str: Option<&str>) {
|
||||
|
||||
|
||||
let mut trackers = std::collections::HashMap::new();
|
||||
for (shard, endpoint_url) in shard_urls {
|
||||
//
|
||||
// Create a pool of streams for streaming get_page requests
|
||||
//
|
||||
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
|
||||
endpoint_url.clone(),
|
||||
self.tcp_client_cache_options.max_delay_ms,
|
||||
self.tcp_client_cache_options.drop_rate,
|
||||
self.tcp_client_cache_options.hang_rate,
|
||||
));
|
||||
let new_pool: Arc<ConnectionPool<Channel>>;
|
||||
new_pool = ConnectionPool::new(
|
||||
Arc::clone(&channel_fact),
|
||||
self.tcp_client_cache_options.connect_timeout,
|
||||
self.tcp_client_cache_options.connect_backoff,
|
||||
self.tcp_client_cache_options.max_consumers,
|
||||
self.tcp_client_cache_options.error_threshold,
|
||||
self.tcp_client_cache_options.max_idle_duration,
|
||||
self.tcp_client_cache_options.max_total_connections,
|
||||
metrics.clone(),
|
||||
);
|
||||
|
||||
let auth_interceptor = AuthInterceptor::new(tenant_id.as_str(),
|
||||
timeline_id.as_str(),
|
||||
auth_str);
|
||||
|
||||
let stream_pool = ConnectionPool::<StreamReturner>::new(
|
||||
Arc::new(StreamFactory::new(new_pool.clone(),
|
||||
auth_interceptor.clone(), ShardIndex::unsharded())),
|
||||
self.stream_client_cache_options.connect_timeout,
|
||||
self.stream_client_cache_options.connect_backoff,
|
||||
self.stream_client_cache_options.max_consumers,
|
||||
self.stream_client_cache_options.error_threshold,
|
||||
self.stream_client_cache_options.max_idle_duration,
|
||||
self.stream_client_cache_options.max_total_connections,
|
||||
metrics.clone(),
|
||||
);
|
||||
|
||||
//
|
||||
// Create a client pool for unary requests
|
||||
//
|
||||
|
||||
let unary_pool: Arc<ConnectionPool<Channel>>;
|
||||
unary_pool = ConnectionPool::new(
|
||||
Arc::clone(&channel_fact),
|
||||
self.tcp_client_cache_options.connect_timeout,
|
||||
self.tcp_client_cache_options.connect_backoff,
|
||||
self.tcp_client_cache_options.max_consumers,
|
||||
self.tcp_client_cache_options.error_threshold,
|
||||
self.tcp_client_cache_options.max_idle_duration,
|
||||
self.tcp_client_cache_options.max_total_connections,
|
||||
metrics.clone()
|
||||
);
|
||||
//
|
||||
// Create a new RequestTracker for this shard
|
||||
//
|
||||
let new_tracker = RequestTracker::new(stream_pool, unary_pool, auth_interceptor, shard);
|
||||
trackers.insert(shard, new_tracker);
|
||||
}
|
||||
let mut inner = self.inner.lock().await;
|
||||
inner.trackers = trackers;
|
||||
}
|
||||
|
||||
pub async fn get_page(
|
||||
&self,
|
||||
req: GetPageRequest,
|
||||
) -> Result<GetPageResponse, tonic::Status> {
|
||||
|
||||
// Get shard index from the request
|
||||
let shard_index = ShardIndex::unsharded();
|
||||
let inner = self.inner.lock().await;
|
||||
let mut tracker : RequestTracker;
|
||||
if let Some(t) = inner.trackers.get(&shard_index) {
|
||||
tracker = t.clone();
|
||||
} else {
|
||||
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
|
||||
}
|
||||
drop(inner);
|
||||
// Call the send_getpage_request method on the tracker
|
||||
let response = tracker.send_getpage_request(req).await;
|
||||
match response {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(tonic::Status::unknown(format!("Failed to get page: {}", e))),
|
||||
}
|
||||
}
|
||||
pub async fn process_get_dbsize_request(
|
||||
&self,
|
||||
request: GetDbSizeRequest,
|
||||
) -> Result<u64, tonic::Status> {
|
||||
let shard_index = ShardIndex::unsharded();
|
||||
let inner = self.inner.lock().await;
|
||||
let mut tracker: RequestTracker;
|
||||
if let Some(t) = inner.trackers.get(&shard_index) {
|
||||
tracker = t.clone();
|
||||
} else {
|
||||
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
|
||||
}
|
||||
drop(inner); // Release the lock before calling send_process_get_dbsize_request
|
||||
// Call the send_process_get_dbsize_request method on the tracker
|
||||
let response = tracker.send_process_get_dbsize_request(request).await;
|
||||
match response {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn process_get_rel_size_request(
|
||||
&self,
|
||||
request: GetRelSizeRequest,
|
||||
) -> Result<u32, tonic::Status> {
|
||||
let shard_index = ShardIndex::unsharded();
|
||||
let inner = self.inner.lock().await;
|
||||
let mut tracker: RequestTracker;
|
||||
if let Some(t) = inner.trackers.get(&shard_index) {
|
||||
tracker = t.clone();
|
||||
} else {
|
||||
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
|
||||
}
|
||||
drop(inner); // Release the lock before calling send_process_get_rel_size_request
|
||||
// Call the send_process_get_rel_size_request method on the tracker
|
||||
let response = tracker.send_process_get_rel_size_request(request).await;
|
||||
match response {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn process_check_rel_exists_request(
|
||||
&self,
|
||||
request: CheckRelExistsRequest,
|
||||
) -> Result<bool, tonic::Status> {
|
||||
let shard_index = ShardIndex::unsharded();
|
||||
let inner = self.inner.lock().await;
|
||||
let mut tracker: RequestTracker;
|
||||
if let Some(t) = inner.trackers.get(&shard_index) {
|
||||
tracker = t.clone();
|
||||
} else {
|
||||
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
|
||||
}
|
||||
drop(inner); // Release the lock before calling send_process_check_rel_exists_request
|
||||
// Call the send_process_check_rel_exists_request method on the tracker
|
||||
let response = tracker.send_process_check_rel_exists_request(request).await;
|
||||
match response {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -487,6 +487,7 @@ impl From<GetPageStatusCode> for i32 {
|
||||
|
||||
// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other
|
||||
// shards will error.
|
||||
#[derive(Clone)]
|
||||
pub struct GetRelSizeRequest {
|
||||
pub read_lsn: ReadLsn,
|
||||
pub rel: RelTag,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
use std::collections::{HashSet, HashMap, VecDeque};
|
||||
use std::future::Future;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::io::Error;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
@@ -23,6 +24,8 @@ use tracing::info;
|
||||
use utils::id::TenantTimelineId;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use tonic::transport::Channel;
|
||||
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::State;
|
||||
@@ -427,6 +430,7 @@ async fn main_impl(
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
|
||||
};
|
||||
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
|
||||
})
|
||||
@@ -694,6 +698,7 @@ impl Client for LibpqClient {
|
||||
struct GrpcClient {
|
||||
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
|
||||
resp_rx: tonic::Streaming<proto::GetPageResponse>,
|
||||
start_times: Vec<Instant>,
|
||||
}
|
||||
|
||||
impl GrpcClient {
|
||||
@@ -717,6 +722,7 @@ impl GrpcClient {
|
||||
Ok(Self {
|
||||
req_tx,
|
||||
resp_rx: resp_stream,
|
||||
start_times: Vec::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -741,6 +747,7 @@ impl Client for GrpcClient {
|
||||
rel: Some(rel.into()),
|
||||
block_number: blks,
|
||||
};
|
||||
self.start_times.push(Instant::now());
|
||||
self.req_tx.send(req).await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -755,3 +762,4 @@ impl Client for GrpcClient {
|
||||
Ok((resp.request_id, resp.page_image))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess};
|
||||
use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest};
|
||||
use crate::neon_request::{NeonIORequest, NeonIOResult};
|
||||
use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable};
|
||||
use pageserver_client_grpc::PageserverClient;
|
||||
use pageserver_client_grpc::request_tracker::ShardedRequestTracker;
|
||||
use pageserver_page_api as page_api;
|
||||
|
||||
use metrics::{IntCounter, IntCounterVec};
|
||||
@@ -30,7 +30,7 @@ use utils::lsn::Lsn;
|
||||
pub struct CommunicatorWorkerProcessStruct<'a> {
|
||||
neon_request_slots: &'a [NeonIOHandle],
|
||||
|
||||
pageserver_client: PageserverClient,
|
||||
request_tracker: ShardedRequestTracker,
|
||||
|
||||
pub(crate) cache: IntegratedCacheWriteAccess<'a>,
|
||||
|
||||
@@ -74,6 +74,7 @@ pub(super) async fn init(
|
||||
initial_file_cache_size: u64,
|
||||
file_cache_path: Option<PathBuf>,
|
||||
) -> CommunicatorWorkerProcessStruct<'static> {
|
||||
info!("Test log message");
|
||||
let last_lsn = get_request_lsn();
|
||||
|
||||
let file_cache = if let Some(path) = file_cache_path {
|
||||
@@ -97,7 +98,12 @@ pub(super) async fn init(
|
||||
.integrated_cache_init_struct
|
||||
.worker_process_init(last_lsn, file_cache);
|
||||
|
||||
let pageserver_client = PageserverClient::new(&tenant_id, &timeline_id, &auth_token, shard_map);
|
||||
let mut request_tracker = ShardedRequestTracker::new();
|
||||
request_tracker.update_shard_map(shard_map,
|
||||
None,
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
auth_token.as_deref()).await;
|
||||
|
||||
let request_counters = IntCounterVec::new(
|
||||
metrics::core::Opts::new(
|
||||
@@ -148,7 +154,7 @@ pub(super) async fn init(
|
||||
|
||||
CommunicatorWorkerProcessStruct {
|
||||
neon_request_slots: cis.neon_request_slots,
|
||||
pageserver_client,
|
||||
request_tracker,
|
||||
cache,
|
||||
submission_pipe_read_fd: cis.submission_pipe_read_fd,
|
||||
next_request_id: AtomicU64::new(1),
|
||||
@@ -257,7 +263,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
|
||||
};
|
||||
|
||||
match self
|
||||
.pageserver_client
|
||||
.request_tracker
|
||||
.process_check_rel_exists_request(page_api::CheckRelExistsRequest {
|
||||
read_lsn: self.request_lsns(not_modified_since),
|
||||
rel,
|
||||
@@ -291,7 +297,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
|
||||
|
||||
let read_lsn = self.request_lsns(not_modified_since);
|
||||
match self
|
||||
.pageserver_client
|
||||
.request_tracker
|
||||
.process_get_rel_size_request(page_api::GetRelSizeRequest {
|
||||
read_lsn,
|
||||
rel: rel.clone(),
|
||||
@@ -344,7 +350,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
|
||||
};
|
||||
|
||||
match self
|
||||
.pageserver_client
|
||||
.request_tracker
|
||||
.process_get_dbsize_request(page_api::GetDbSizeRequest {
|
||||
read_lsn: self.request_lsns(not_modified_since),
|
||||
db_oid: req.db_oid,
|
||||
@@ -467,7 +473,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
|
||||
// TODO: Use batched protocol
|
||||
for (blkno, _lsn, dest, _guard) in cache_misses.iter() {
|
||||
match self
|
||||
.pageserver_client
|
||||
.request_tracker
|
||||
.get_page(page_api::GetPageRequest {
|
||||
request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed),
|
||||
request_class: page_api::GetPageClass::Normal,
|
||||
@@ -477,11 +483,11 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(page_images) => {
|
||||
Ok(resp) => {
|
||||
// Write the received page image directly to the shared memory location
|
||||
// that the backend requested.
|
||||
assert!(page_images.len() == 1);
|
||||
let page_image = page_images[0].clone();
|
||||
assert!(resp.page_images.len() == 1);
|
||||
let page_image = resp.page_images[0].clone();
|
||||
let src: &[u8] = page_image.as_ref();
|
||||
let len = std::cmp::min(src.len(), dest.bytes_total() as usize);
|
||||
unsafe {
|
||||
@@ -545,7 +551,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
|
||||
// TODO: Use batched protocol
|
||||
for (blkno, _lsn, _guard) in cache_misses.iter() {
|
||||
match self
|
||||
.pageserver_client
|
||||
.request_tracker
|
||||
.get_page(page_api::GetPageRequest {
|
||||
request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed),
|
||||
request_class: page_api::GetPageClass::Prefetch,
|
||||
@@ -555,13 +561,13 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> {
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(page_images) => {
|
||||
Ok(resp) => {
|
||||
trace!(
|
||||
"prefetch completed, remembering blk {} in rel {:?} in LFC",
|
||||
*blkno, rel
|
||||
);
|
||||
assert!(page_images.len() == 1);
|
||||
let page_image = page_images[0].clone();
|
||||
assert!(resp.page_images.len() == 1);
|
||||
let page_image = resp.page_images[0].clone();
|
||||
self.cache
|
||||
.remember_page(&rel, *blkno, page_image, not_modified_since, false)
|
||||
.await;
|
||||
|
||||
Reference in New Issue
Block a user