From 7160fd16cd3c0ed4279f2573ce499e8047974edc Mon Sep 17 00:00:00 2001 From: Elizabeth Murray Date: Wed, 28 May 2025 12:40:21 -0700 Subject: [PATCH] Response to review comments, code cleanup. --- Cargo.lock | 2 - Cargo.toml | 2 +- pageserver/client_grpc/Cargo.toml | 4 +- pageserver/client_grpc/src/lib.rs | 90 +++++++++++++++++-------------- 4 files changed, 53 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94dba46aa9..f71514507c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4442,7 +4442,6 @@ dependencies = [ "pageserver_page_api", "thiserror 1.0.69", "tokio", - "tokio-util", "tonic 0.13.1", "tracing", "utils", @@ -7554,7 +7553,6 @@ dependencies = [ "axum", "base64 0.22.1", "bytes", - "flate2", "h2 0.4.4", "http 1.1.0", "http-body 1.0.0", diff --git a/Cargo.toml b/Cargo.toml index 4790497d8b..4f0fe6cd88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -200,7 +200,7 @@ tokio-tar = "0.3" tokio-util = { version = "0.7.10", features = ["io", "rt"] } toml = "0.8" toml_edit = "0.22" -tonic = { version = "0.13.1", default-features = false, features = ["gzip", "channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] } +tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] } tonic-reflection = { version = "0.13.1", features = ["server"] } tower = { version = "0.5.2", default-features = false } tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] } diff --git a/pageserver/client_grpc/Cargo.toml b/pageserver/client_grpc/Cargo.toml index fe644cf529..8dcb94e842 100644 --- a/pageserver/client_grpc/Cargo.toml +++ b/pageserver/client_grpc/Cargo.toml @@ -1,7 +1,8 @@ [package] name = "pageserver_client_grpc" version = "0.1.0" -edition = "2024" +edition.workspace = true +license.workspace = true [dependencies] bytes.workspace = true @@ -13,4 +14,3 @@ tracing.workspace = true pageserver_page_api.workspace = true utils.workspace = true tokio.workspace = true -tokio-util = { version = "0.7", features = ["compat"] } diff --git a/pageserver/client_grpc/src/lib.rs b/pageserver/client_grpc/src/lib.rs index c7e3cb10b4..479054d0e2 100644 --- a/pageserver/client_grpc/src/lib.rs +++ b/pageserver/client_grpc/src/lib.rs @@ -1,13 +1,12 @@ -// -// Pageserver gRPC client library -// -// This library provides a gRPC client for the pageserver for the -// communicator project. -// -// This library is a work in progress. -// -// TODO: This should properly use the shard map -// +//! +//! Pageserver gRPC client library +//! +//! This library provides a gRPC client for the pageserver for the +//! communicator project. +//! +//! This library is a work in progress. +//! +//! use std::collections::HashMap; use bytes::Bytes; @@ -38,7 +37,7 @@ pub enum PageserverClientError { } pub struct PageserverClient { - shard_map: HashMap, + endpoint_map: HashMap, channels: tokio::sync::RwLock>, auth_interceptor: AuthInterceptor, } @@ -46,24 +45,38 @@ pub struct PageserverClient { impl PageserverClient { /// TODO: this doesn't currently react to changes in the shard map. pub fn new( - auth_interceptor: AuthInterceptor, + tenant_id: AsciiMetadataValue, + timeline_id: AsciiMetadataValue, + auth_token: Option, shard_map: HashMap, - ) -> Self { - Self { - shard_map, + ) -> Result { + let endpoint_map: HashMap = shard_map + .into_iter() + .map(|(shard, url)| { + let endpoint = Endpoint::from_shared(url) + .map_err(|_e| PageserverClientError::Other("Unable to parse endpoint {url}".to_string()))?; + Ok::<(ShardIndex, Endpoint), PageserverClientError>((shard, endpoint)) + }) + .collect::>()?; + Ok(Self { + endpoint_map, channels: RwLock::new(HashMap::new()), - auth_interceptor: auth_interceptor, - } + auth_interceptor: AuthInterceptor::new( + tenant_id, + timeline_id, + auth_token, + ), + }) } // // TODO: This opens a new gRPC stream for every request, which is extremely inefficient pub async fn get_page( &self, + shard: ShardIndex, request: &GetPageRequest, ) -> Result, PageserverClientError> { // FIXME: calculate the shard number correctly - let shard = ShardIndex::unsharded(); - let chan = self.get_client(shard).await; + let chan = self.get_client(shard).await?; let mut client = PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); @@ -98,34 +111,31 @@ impl PageserverClient { // TODO: this should use a connection pool with concurrency limits, // not a single connection to the shard. // - async fn get_client(&self, shard: ShardIndex) -> Channel { + async fn get_client(&self, shard: ShardIndex) -> Result { // Get channel from the hashmap let mut channels = self.channels.write(); if let Some(channel) = channels.await.get(&shard) { - return channel.clone(); + return Ok(channel.clone()); } // Create a new channel if it doesn't exist - let shard_url = self - .shard_map - .get(&shard) - .expect("shard not found in shard map"); + let shard_endpoint = self + .endpoint_map + .get(&shard); - let attempt = Endpoint::from_shared(shard_url.clone()) - .expect("invalid endpoint") - .connect() - .await; + let endpoint = match shard_endpoint{ + Some(_endpoint) => _endpoint, + None => { + error!("Shard {shard} not found in shard map"); + return Err(PageserverClientError::Other(format!( + "Shard {shard} not found in shard map" + ))); + } + }; - match attempt { - Ok(channel) => { - channels = self.channels.write(); - channels.await.insert(shard, channel.clone()); - channel.clone() - } - Err(e) => { - // TODO: handle this more gracefully, e.g. with a connection pool retry - panic!("Failed to connect to shard {shard}: {e}"); - } - } + let channel = endpoint.connect().await?; + channels = self.channels.write(); + channels.await.insert(shard, channel.clone()); + Ok(channel.clone()) } }