Fix a bunch of linter warnings

Erik Grinaker
2025-06-30 11:10:02 +02:00
parent 9d9e3cd08a
commit 67b04f8ab3
28 changed files with 228 additions and 404 deletions

View File

@@ -48,18 +48,11 @@ impl Drop for MockConnection {
}
}
#[derive(Default)]
pub struct MockConnectionFactory {
counter: AtomicU64,
}
impl MockConnectionFactory {
pub fn new() -> Self {
MockConnectionFactory {
counter: AtomicU64::new(1),
}
}
}
#[async_trait::async_trait]
impl PooledItemFactory<MockConnection> for MockConnectionFactory {
/// The trait on ConnectionPool expects:
@@ -171,7 +164,7 @@ async fn main() {
// --------------------------------------
// 1. Create factory and shared instrumentation
// --------------------------------------
let factory = Arc::new(MockConnectionFactory::new());
let factory = Arc::new(MockConnectionFactory::default());
// Shared map: connection ID → Arc<AtomicUsize>
let usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>> =
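
Note on this hunk: #[derive(Default)] zero-initializes the AtomicU64, while the removed constructor started the counter at 1. A minimal sketch, assuming the starting value of 1 was intentional, that keeps the old behavior (replacing the derive) and still satisfies clippy::new_without_default:

impl Default for MockConnectionFactory {
    fn default() -> Self {
        // Preserve the original starting counter value of 1.
        MockConnectionFactory {
            counter: AtomicU64::new(1),
        }
    }
}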

View File

@@ -1,137 +0,0 @@
// examples/request_tracker_load_test.rs
use pageserver_client_grpc::AuthInterceptor;
use pageserver_client_grpc::ClientCacheOptions;
use pageserver_client_grpc::PageserverClientAggregateMetrics;
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
use pageserver_client_grpc::request_tracker::MockStreamFactory;
use pageserver_client_grpc::request_tracker::RequestTracker;
use pageserver_client_grpc::request_tracker::StreamReturner;
use std::{sync::Arc, time::Duration};
use tokio;
use pageserver_client_grpc::client_cache::ChannelFactory;
use tonic::transport::Channel;
use rand::prelude::*;
use pageserver_api::key::Key;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use futures::StreamExt;
use futures::stream::FuturesOrdered;
use pageserver_page_api::proto;
#[tokio::main]
async fn main() {
// 1) configure the clientpool behavior
let client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(10),
connect_backoff: Duration::from_millis(200),
max_consumers: 64,
error_threshold: 10,
max_idle_duration: Duration::from_secs(60),
max_total_connections: 12,
};
// 2) metrics collector (we assume Default is implemented)
let metrics = Arc::new(PageserverClientAggregateMetrics::new());
let pool = ConnectionPool::<StreamReturner>::new(
Arc::new(MockStreamFactory::new()),
client_cache_options.connect_timeout,
client_cache_options.connect_backoff,
client_cache_options.max_consumers,
client_cache_options.error_threshold,
client_cache_options.max_idle_duration,
client_cache_options.max_total_connections,
Some(Arc::clone(&metrics)),
);
// -----------
// There is no mock for the unary connection pool, so for now just
// don't use this pool
//
let channel_fact: Arc<dyn PooledItemFactory<Channel> + Send + Sync> =
Arc::new(ChannelFactory::new(
"".to_string(),
client_cache_options.max_delay_ms,
client_cache_options.drop_rate,
client_cache_options.hang_rate,
));
let unary_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
Arc::clone(&channel_fact),
client_cache_options.connect_timeout,
client_cache_options.connect_backoff,
client_cache_options.max_consumers,
client_cache_options.error_threshold,
client_cache_options.max_idle_duration,
client_cache_options.max_total_connections,
Some(Arc::clone(&metrics)),
);
// -----------
// Dummy auth interceptor. This is not used in this test.
let auth_interceptor = AuthInterceptor::new("dummy_tenant_id", "dummy_timeline_id", None);
let tracker = RequestTracker::new(pool, unary_pool, auth_interceptor, ShardIndex::unsharded());
// 4) fire off 500,000 requests in parallel
let mut handles = FuturesOrdered::new();
for _i in 0..500000 {
let mut rng = rand::thread_rng();
let r = 0..=1000000i128;
let key: i128 = rng.gen_range(r.clone());
let key = Key::from_i128(key);
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
let req2 = proto::GetPageRequest {
request_id: 0,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: if rng.gen_bool(0.5) {
u64::from(Lsn::MAX)
} else {
10000
},
not_modified_since_lsn: 10000,
}),
rel: Some(rel_tag.into()),
block_number: vec![block_no],
};
let req_model = pageserver_page_api::GetPageRequest::try_from(req2.clone());
// RequestTracker is Clone, so we can share it
let mut tr = tracker.clone();
let fut = async move {
let resp = tr.send_getpage_request(req_model.unwrap()).await.unwrap();
// sanity check: the mock echo returns the same request_id
assert!(resp.request_id > 0);
};
handles.push_back(fut);
// empty future
let fut = async move {};
fut.await;
}
// print timestamp
println!("Starting 5000000 requests at: {}", chrono::Utc::now());
// 5) wait for them all
for _i in 0..500000 {
handles.next().await.expect("Failed to get next handle");
}
// print timestamp
println!("Finished 5000000 requests at: {}", chrono::Utc::now());
println!("✅ All 100000 requests completed successfully");
}

View File

@@ -217,10 +217,7 @@ impl PooledItemFactory<Channel> for ChannelFactory {
// Random drop (connect error)
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"simulated connect drop",
));
return Err(std::io::Error::other("simulated connect drop"));
}
// Otherwise perform real TCP connect
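
This hunk applies the clippy::io_other_error suggestion. A minimal standalone sketch of the two equivalent spellings, assuming a toolchain of Rust 1.74 or newer where io::Error::other is stable:

use std::io;

fn simulated_connect_drop() -> io::Result<()> {
    // Old spelling, flagged by the lint:
    // Err(io::Error::new(io::ErrorKind::Other, "simulated connect drop"))
    // Preferred shorthand:
    Err(io::Error::other("simulated connect drop"))
}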
@@ -309,6 +306,7 @@ pub struct PooledClient<T> {
}
impl<T: Clone + Send + 'static> ConnectionPool<T> {
#[allow(clippy::too_many_arguments)]
pub fn new(
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
connect_timeout: Duration,
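
The #[allow(clippy::too_many_arguments)] silences the lint on ConnectionPool::new. A hedged alternative sketch (PoolOptions is hypothetical, loosely mirroring the existing ClientCacheOptions fields) would group the knobs into one struct:

use std::time::Duration;

// Field types below are guesses based on how the values are used.
pub struct PoolOptions {
    pub connect_timeout: Duration,
    pub connect_backoff: Duration,
    pub max_consumers: usize,
    pub error_threshold: usize,
    pub max_idle_duration: Duration,
    pub max_total_connections: usize,
}

// ConnectionPool::new(fact, options, metrics) would then need only three parameters.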
@@ -391,14 +389,11 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
&& now.duration_since(entry.last_used) > self.max_idle_duration
{
// metric
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_swept"])
.inc();
}
None => {}
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_swept"])
.inc();
}
ids_to_remove.push(*id);
return false; // remove this entry
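
The same rewrite recurs throughout this commit: a match on an Option with an empty None arm collapses to if let (clippy::single_match). A minimal standalone sketch of the shape, with an AtomicU64 standing in for the metrics counter:

use std::sync::atomic::{AtomicU64, Ordering};

fn record_sweep(metrics: &Option<AtomicU64>) {
    // Before: match metrics { Some(m) => { m.fetch_add(1, Ordering::Relaxed); } None => {} }
    if let Some(m) = metrics {
        m.fetch_add(1, Ordering::Relaxed);
    }
}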
@@ -436,7 +431,7 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
pool: Arc::clone(&self),
is_ok: true,
id,
permit: permit,
permit,
};
// reinsert with updated priority
@@ -444,7 +439,7 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
if active_consumers < self.max_consumers {
inner.pq.push(id, active_consumers as usize);
}
return Some(client);
Some(client)
} else {
// If there is no connection to take, it is because permits for a connection
// need to drain. This can happen if a connection is removed because it has
@@ -453,7 +448,7 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
//
// Just forget the permit and retry.
permit.forget();
return None;
None
}
}
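
The permit.forget() path is subtle: the permit is intentionally not returned to the semaphore, because the connection it was sized for no longer exists. A minimal sketch of that pattern, assuming the pool's permits come from a tokio::sync::Semaphore and with connection_available standing in for the pool lookup:

use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

async fn try_take(
    sem: Arc<Semaphore>,
    connection_available: bool,
) -> Option<OwnedSemaphorePermit> {
    let permit = sem.acquire_owned().await.expect("semaphore closed");
    if connection_available {
        // Hold the permit for the lifetime of the pooled client.
        Some(permit)
    } else {
        // The backing connection was removed; shrink capacity permanently
        // by forgetting the permit, then let the caller retry.
        permit.forget();
        None
    }
}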
@@ -485,14 +480,11 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
}
}
Err(_) => {
match self_clone.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["sema_acquire_failed"])
.inc();
}
None => {}
if let Some(ref metrics) = self_clone.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["sema_acquire_success"])
.inc();
}
{
@@ -504,16 +496,15 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
//
let mut inner = self_clone.inner.lock().await;
inner.waiters += 1;
if inner.waiters > (inner.in_progress * self_clone.max_consumers) {
if (inner.entries.len() + inner.in_progress)
if inner.waiters > (inner.in_progress * self_clone.max_consumers)
&& (inner.entries.len() + inner.in_progress)
< self_clone.max_total_connections
{
let self_clone_spawn = Arc::clone(&self_clone);
tokio::task::spawn(async move {
self_clone_spawn.create_connection().await;
});
inner.in_progress += 1;
}
{
let self_clone_spawn = Arc::clone(&self_clone);
tokio::task::spawn(async move {
self_clone_spawn.create_connection().await;
});
inner.in_progress += 1;
}
}
// Wait for a connection to become available, either because it
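
This hunk is clippy::collapsible_if: the nested ifs merge into a single condition joined with &&, with no change in behavior. A minimal sketch of the shape:

fn should_spawn(waiters: usize, in_progress: usize, max_consumers: usize,
                entries: usize, max_total: usize) -> bool {
    // Before: if waiters > in_progress * max_consumers { if entries + in_progress < max_total { ... } }
    waiters > in_progress * max_consumers && entries + in_progress < max_total
}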
@@ -541,7 +532,7 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
}
}
async fn create_connection(&self) -> () {
async fn create_connection(&self) {
// Generate a random backoff to add some jitter so that connections
// don't all retry at the same time.
let mut backoff_delay = Duration::from_millis(
@@ -558,17 +549,13 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
// until the failure stopped for at least one backoff period. Backoff
// period includes some jitter, so that if multiple connections are
// failing, they don't all retry at the same time.
loop {
if let Some(delay) = {
let inner = self.inner.lock().await;
inner.last_connect_failure.and_then(|at| {
(at.elapsed() < backoff_delay).then(|| backoff_delay - at.elapsed())
})
} {
sleep(delay).await;
} else {
break; // No delay, so we can create a connection
}
while let Some(delay) = {
let inner = self.inner.lock().await;
inner.last_connect_failure.and_then(|at| {
(at.elapsed() < backoff_delay).then(|| backoff_delay - at.elapsed())
})
} {
sleep(delay).await;
}
//
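
The loop { if let ... else break } pattern becomes a single while let (clippy::while_let_loop). A minimal standalone sketch of the resulting backoff wait, assuming the failure timestamp is read once up front rather than re-read under the pool lock:

use std::time::{Duration, Instant};
use tokio::time::sleep;

async fn wait_out_backoff(last_failure: Option<Instant>, backoff: Duration) {
    // Keep sleeping while the last failure is still inside the backoff window.
    while let Some(delay) =
        last_failure.and_then(|at| (at.elapsed() < backoff).then(|| backoff - at.elapsed()))
    {
        sleep(delay).await;
    }
}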
@@ -578,14 +565,11 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
// on this connection. (Requests made later on this channel will time out
// with the same timeout.)
//
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_attempt"])
.inc();
}
None => {}
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_attempt"])
.inc();
}
let attempt = self.fact.create(self.connect_timeout).await;
@@ -594,14 +578,11 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
// Connection succeeded
Ok(Ok(channel)) => {
{
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_success"])
.inc();
}
None => {}
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_success"])
.inc();
}
let mut inner = self.inner.lock().await;
let id = uuid::Uuid::new_v4();
@@ -622,14 +603,11 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
}
// Connection failed, back off and retry
Ok(Err(_)) | Err(_) => {
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connect_failed"])
.inc();
}
None => {}
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connect_failed"])
.inc();
}
let mut inner = self.inner.lock().await;
inner.last_connect_failure = Some(Instant::now());
@@ -653,10 +631,10 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
let mut inner = self.inner.lock().await;
if let Some(entry) = inner.entries.get_mut(&id) {
entry.last_used = Instant::now();
if entry.active_consumers <= 0 {
if entry.active_consumers == 0 {
panic!("A consumer completed when active_consumers was zero!")
}
entry.active_consumers = entry.active_consumers - 1;
entry.active_consumers -= 1;
if success {
if entry.consecutive_errors < self.error_threshold {
entry.consecutive_errors = 0;
@@ -664,14 +642,11 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
} else {
entry.consecutive_errors += 1;
if entry.consecutive_errors == self.error_threshold {
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_dropped"])
.inc();
}
None => {}
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.retry_counters
.with_label_values(&["connection_dropped"])
.inc();
}
}
}
@@ -719,7 +694,7 @@ impl<T: Clone + Send + 'static> ConnectionPool<T> {
impl<T: Clone + Send + 'static> PooledClient<T> {
pub fn channel(&self) -> T {
return self.channel.clone();
self.channel.clone()
}
pub async fn finish(mut self, result: Result<(), tonic::Status>) {
self.is_ok = result.is_ok();

View File

@@ -47,6 +47,13 @@ pub struct PageserverClientAggregateMetrics {
pub request_counters: IntCounterVec,
pub retry_counters: IntCounterVec,
}
impl Default for PageserverClientAggregateMetrics {
fn default() -> Self {
Self::new()
}
}
impl PageserverClientAggregateMetrics {
pub fn new() -> Self {
let request_counters = IntCounterVec::new(
@@ -167,11 +174,11 @@ impl PageserverClient {
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().exists);
Ok(resp.get_ref().exists)
}
}
}
@@ -194,11 +201,11 @@ impl PageserverClient {
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_blocks);
Ok(resp.get_ref().num_blocks)
}
}
}
@@ -233,25 +240,22 @@ impl PageserverClient {
));
};
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.request_counters
.with_label_values(&["get_page"])
.inc();
}
None => {}
if let Some(ref metrics) = self.aggregate_metrics {
metrics
.request_counters
.with_label_values(&["get_page"])
.inc();
}
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
let response: GetPageResponse = resp.into();
return Ok(response.page_images.to_vec());
Ok(response.page_images.to_vec())
}
}
}
@@ -280,11 +284,9 @@ impl PageserverClient {
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
}
Ok(resp) => {
return Ok(resp);
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => Ok(resp),
}
}
@@ -307,11 +309,11 @@ impl PageserverClient {
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_bytes);
Ok(resp.get_ref().num_bytes)
}
}
}
@@ -342,11 +344,11 @@ impl PageserverClient {
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
Err(PageserverClientError::RequestError(status))
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp);
Ok(resp)
}
}
}
@@ -360,8 +362,7 @@ impl PageserverClient {
channels.get(&shard).cloned()
};
let usable_pool: Arc<client_cache::ConnectionPool<Channel>>;
match reused_pool {
let usable_pool = match reused_pool {
Some(pool) => {
let pooled_client = pool.get_client().await.unwrap();
return pooled_client;
@@ -370,14 +371,13 @@ impl PageserverClient {
// Create a new pool using client_cache_options
// declare new_pool
let new_pool: Arc<client_cache::ConnectionPool<Channel>>;
let channel_fact = Arc::new(client_cache::ChannelFactory::new(
self.shard_map.get(&shard).unwrap().clone(),
self.client_cache_options.max_delay_ms,
self.client_cache_options.drop_rate,
self.client_cache_options.hang_rate,
));
new_pool = client_cache::ConnectionPool::new(
let new_pool = client_cache::ConnectionPool::new(
channel_fact,
self.client_cache_options.connect_timeout,
self.client_cache_options.connect_backoff,
@@ -389,12 +389,11 @@ impl PageserverClient {
);
let mut write_pool = self.channels.write().unwrap();
write_pool.insert(shard, new_pool.clone());
usable_pool = new_pool.clone();
new_pool.clone()
}
}
};
let pooled_client = usable_pool.get_client().await.unwrap();
return pooled_client;
usable_pool.get_client().await.unwrap()
}
}
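
The usable_pool rewrite is the clippy::needless_late_init fix: bind the value directly from the match expression instead of declaring it first and assigning in each arm. A minimal sketch of the shape:

fn pick_pool(reused: Option<u32>) -> u32 {
    // Before:
    // let pool: u32;
    // match reused { Some(p) => { pool = p; } None => { pool = 0; } }
    let pool = match reused {
        Some(p) => p,
        None => 0,
    };
    pool
}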

View File

@@ -41,12 +41,15 @@ use client_cache::PooledItemFactory;
#[derive(Clone)]
pub struct StreamReturner {
sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
#[allow(clippy::type_complexity)]
sender_hashmap: Arc<
tokio::sync::Mutex<Option<
std::collections::HashMap<
u64,
tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>,
>>,
tokio::sync::Mutex<
Option<
std::collections::HashMap<
u64,
tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>,
>,
>,
>,
>,
}
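
A hedged alternative to #[allow(clippy::type_complexity)]: name the nested sender map once with a type alias (ResponseSenders is a hypothetical name) and reuse it in the struct, relying on the file's existing Arc, Status, and proto imports:

type ResponseSenders = std::collections::HashMap<
    u64,
    tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>,
>;

pub struct StreamReturner {
    sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
    sender_hashmap: Arc<tokio::sync::Mutex<Option<ResponseSenders>>>,
}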
@@ -101,9 +104,9 @@ impl PooledItemFactory<StreamReturner> for StreamFactory {
Ok(resp) => {
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(tokio::sync::Mutex::new(
Some(std::collections::HashMap::new()),
)),
sender_hashmap: Arc::new(tokio::sync::Mutex::new(Some(
std::collections::HashMap::new(),
))),
};
let map = Arc::clone(&stream_returner.sender_hashmap);
@@ -122,7 +125,8 @@ impl PooledItemFactory<StreamReturner> for StreamFactory {
Ok(Some(response)) => {
// look up stream in hash map
let mut hashmap = map_clone.lock().await;
let hashmap = hashmap.as_mut().expect("no other task clears the hashmap");
let hashmap =
hashmap.as_mut().expect("no other task clears the hashmap");
if let Some(sender) = hashmap.get(&response.request_id) {
// Send the response to the original request sender
if let Err(e) = sender.send(Ok(response.clone())).await {
@@ -130,7 +134,10 @@ impl PooledItemFactory<StreamReturner> for StreamFactory {
}
hashmap.remove(&response.request_id);
} else {
eprintln!("No sender found for request ID: {}", response.request_id);
eprintln!(
"No sender found for request ID: {}",
response.request_id
);
}
}
}
@@ -139,7 +146,9 @@ impl PooledItemFactory<StreamReturner> for StreamFactory {
// Close every sender stream in the hashmap
let mut hashmap_opt = map_clone.lock().await;
let hashmap = hashmap_opt.as_mut().expect("no other task clears the hashmap");
let hashmap = hashmap_opt
.as_mut()
.expect("no other task clears the hashmap");
for sender in hashmap.values() {
let error = Status::new(Code::Unknown, "Stream closed");
if let Err(e) = sender.send(Err(error)).await {
@@ -175,10 +184,10 @@ impl RequestTracker {
RequestTracker {
_cur_id: cur_id.clone(),
stream_pool: stream_pool,
unary_pool: unary_pool,
auth_interceptor: auth_interceptor,
shard: shard.clone(),
stream_pool,
unary_pool,
auth_interceptor,
shard,
}
}
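
The RequestTracker::new change is clippy::redundant_field_names: when the local variable and the struct field share a name, the shorthand suffices. A minimal sketch:

struct Tracker {
    shard: u32,
}

fn new_tracker(shard: u32) -> Tracker {
    // Before: Tracker { shard: shard }
    Tracker { shard }
}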
@@ -194,7 +203,7 @@ impl RequestTracker {
channel,
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::CheckRelExistsRequest::from(req.clone());
let request = proto::CheckRelExistsRequest::from(req);
let response = ps_client
.check_rel_exists(tonic::Request::new(request))
.await;
@@ -226,7 +235,7 @@ impl RequestTracker {
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::GetRelSizeRequest::from(req.clone());
let request = proto::GetRelSizeRequest::from(req);
let response = ps_client.get_rel_size(tonic::Request::new(request)).await;
match response {
@@ -256,7 +265,7 @@ impl RequestTracker {
self.auth_interceptor.for_shard(self.shard),
);
let request = proto::GetDbSizeRequest::from(req.clone());
let request = proto::GetDbSizeRequest::from(req);
let response = ps_client.get_db_size(tonic::Request::new(request)).await;
match response {
@@ -335,8 +344,7 @@ impl RequestTracker {
continue;
}
let response: Option<Result<proto::GetPageResponse, Status>>;
response = response_receiver.recv().await;
let response = response_receiver.recv().await;
match response {
Some(resp) => {
match resp {
@@ -382,6 +390,13 @@ pub struct ShardedRequestTracker {
// TODO: Functions in the ShardedRequestTracker should be able to time out and
// cancel a request. The request should return an error if it is cancelled.
//
impl Default for ShardedRequestTracker {
fn default() -> Self {
ShardedRequestTracker::new()
}
}
impl ShardedRequestTracker {
pub fn new() -> Self {
//
@@ -438,8 +453,7 @@ impl ShardedRequestTracker {
self.tcp_client_cache_options.drop_rate,
self.tcp_client_cache_options.hang_rate,
));
let new_pool: Arc<ConnectionPool<Channel>>;
new_pool = ConnectionPool::new(
let new_pool = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
@@ -472,8 +486,7 @@ impl ShardedRequestTracker {
// Create a client pool for unary requests
//
let unary_pool: Arc<ConnectionPool<Channel>>;
unary_pool = ConnectionPool::new(
let unary_pool = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
@@ -547,6 +560,7 @@ impl ShardedRequestTracker {
}
}
#[allow(clippy::result_large_err)]
fn lookup_tracker_for_shard(
&self,
shard_index: ShardIndex,

View File

@@ -34,7 +34,6 @@ use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use metrics;
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};