Response to review comments, code cleanup.

This commit is contained in:
Elizabeth Murray
2025-05-28 12:40:21 -07:00
parent 13b9d4cb67
commit 7160fd16cd
4 changed files with 53 additions and 45 deletions

2
Cargo.lock generated
View File

@@ -4442,7 +4442,6 @@ dependencies = [
"pageserver_page_api",
"thiserror 1.0.69",
"tokio",
"tokio-util",
"tonic 0.13.1",
"tracing",
"utils",
@@ -7554,7 +7553,6 @@ dependencies = [
"axum",
"base64 0.22.1",
"bytes",
"flate2",
"h2 0.4.4",
"http 1.1.0",
"http-body 1.0.0",

View File

@@ -200,7 +200,7 @@ tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
toml = "0.8"
toml_edit = "0.22"
tonic = { version = "0.13.1", default-features = false, features = ["gzip", "channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] }
tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] }
tonic-reflection = { version = "0.13.1", features = ["server"] }
tower = { version = "0.5.2", default-features = false }
tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] }

View File

@@ -1,7 +1,8 @@
[package]
name = "pageserver_client_grpc"
version = "0.1.0"
edition = "2024"
edition.workspace = true
license.workspace = true
[dependencies]
bytes.workspace = true
@@ -13,4 +14,3 @@ tracing.workspace = true
pageserver_page_api.workspace = true
utils.workspace = true
tokio.workspace = true
tokio-util = { version = "0.7", features = ["compat"] }

View File

@@ -1,13 +1,12 @@
//
// Pageserver gRPC client library
//
// This library provides a gRPC client for the pageserver for the
// communicator project.
//
// This library is a work in progress.
//
// TODO: This should properly use the shard map
//
//!
//! Pageserver gRPC client library
//!
//! This library provides a gRPC client for the pageserver for the
//! communicator project.
//!
//! This library is a work in progress.
//!
//!
use std::collections::HashMap;
use bytes::Bytes;
@@ -38,7 +37,7 @@ pub enum PageserverClientError {
}
pub struct PageserverClient {
shard_map: HashMap<ShardIndex, String>,
endpoint_map: HashMap<ShardIndex, Endpoint>,
channels: tokio::sync::RwLock<HashMap<ShardIndex, Channel>>,
auth_interceptor: AuthInterceptor,
}
@@ -46,24 +45,38 @@ pub struct PageserverClient {
impl PageserverClient {
/// TODO: this doesn't currently react to changes in the shard map.
pub fn new(
auth_interceptor: AuthInterceptor,
tenant_id: AsciiMetadataValue,
timeline_id: AsciiMetadataValue,
auth_token: Option<String>,
shard_map: HashMap<ShardIndex, String>,
) -> Self {
Self {
shard_map,
) -> Result<Self, PageserverClientError> {
let endpoint_map: HashMap<ShardIndex, Endpoint> = shard_map
.into_iter()
.map(|(shard, url)| {
let endpoint = Endpoint::from_shared(url)
.map_err(|_e| PageserverClientError::Other("Unable to parse endpoint {url}".to_string()))?;
Ok::<(ShardIndex, Endpoint), PageserverClientError>((shard, endpoint))
})
.collect::<Result<_, _>>()?;
Ok(Self {
endpoint_map,
channels: RwLock::new(HashMap::new()),
auth_interceptor: auth_interceptor,
}
auth_interceptor: AuthInterceptor::new(
tenant_id,
timeline_id,
auth_token,
),
})
}
//
// TODO: This opens a new gRPC stream for every request, which is extremely inefficient
pub async fn get_page(
&self,
shard: ShardIndex,
request: &GetPageRequest,
) -> Result<Vec<Bytes>, PageserverClientError> {
// FIXME: calculate the shard number correctly
let shard = ShardIndex::unsharded();
let chan = self.get_client(shard).await;
let chan = self.get_client(shard).await?;
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
@@ -98,34 +111,31 @@ impl PageserverClient {
// TODO: this should use a connection pool with concurrency limits,
// not a single connection to the shard.
//
async fn get_client(&self, shard: ShardIndex) -> Channel {
async fn get_client(&self, shard: ShardIndex) -> Result<Channel, PageserverClientError> {
// Get channel from the hashmap
let mut channels = self.channels.write();
if let Some(channel) = channels.await.get(&shard) {
return channel.clone();
return Ok(channel.clone());
}
// Create a new channel if it doesn't exist
let shard_url = self
.shard_map
.get(&shard)
.expect("shard not found in shard map");
let shard_endpoint = self
.endpoint_map
.get(&shard);
let attempt = Endpoint::from_shared(shard_url.clone())
.expect("invalid endpoint")
.connect()
.await;
let endpoint = match shard_endpoint{
Some(_endpoint) => _endpoint,
None => {
error!("Shard {shard} not found in shard map");
return Err(PageserverClientError::Other(format!(
"Shard {shard} not found in shard map"
)));
}
};
match attempt {
Ok(channel) => {
channels = self.channels.write();
channels.await.insert(shard, channel.clone());
channel.clone()
}
Err(e) => {
// TODO: handle this more gracefully, e.g. with a connection pool retry
panic!("Failed to connect to shard {shard}: {e}");
}
}
let channel = endpoint.connect().await?;
channels = self.channels.write();
channels.await.insert(shard, channel.clone());
Ok(channel.clone())
}
}