pageserver/client_grpc: add shard map updates

This commit is contained in:
Erik Grinaker
2025-07-06 13:14:43 +02:00
parent 74e0d85a04
commit 4b06b547c1
6 changed files with 181 additions and 74 deletions

1
Cargo.lock generated
View File

@@ -4631,6 +4631,7 @@ name = "pageserver_client_grpc"
version = "0.1.0"
dependencies = [
"anyhow",
"arc-swap",
"async-trait",
"bytes",
"chrono",

View File

@@ -8,6 +8,7 @@ testing = ["pageserver_api/testing"]
[dependencies]
anyhow.workspace = true
arc-swap.workspace = true
bytes.workspace = true
futures.workspace = true
http.workspace = true

View File

@@ -3,9 +3,10 @@ use std::num::NonZero;
use std::sync::Arc;
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::FuturesUnordered;
use futures::{FutureExt as _, StreamExt as _};
use tracing::instrument;
use tracing::{instrument, warn};
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::retry::Retry;
@@ -55,8 +56,15 @@ const MAX_BULK_STREAM_QUEUE_DEPTH: NonZero<usize> = NonZero::new(4).unwrap();
/// TODO: this client does not support base backups or LSN leases, as these are only used by
/// compute_ctl. Consider adding this, but LSN leases need concurrent requests on all shards.
pub struct PageserverClient {
// TODO: support swapping out the shard map, e.g. via an ArcSwap.
shards: Shards,
/// The tenant ID.
tenant_id: TenantId,
/// The timeline ID.
timeline_id: TimelineId,
/// The JWT auth token for this tenant, if any.
auth_token: Option<String>,
/// The shards for this tenant.
shards: ArcSwap<Shards>,
/// The retry configuration.
retry: Retry,
}
@@ -66,17 +74,35 @@ impl PageserverClient {
pub fn new(
tenant_id: TenantId,
timeline_id: TimelineId,
shard_map: HashMap<ShardIndex, String>,
stripe_size: ShardStripeSize,
shard_map: ShardMap,
auth_token: Option<String>,
) -> anyhow::Result<Self> {
let shards = Shards::new(tenant_id, timeline_id, shard_map, stripe_size, auth_token)?;
let shards = Shards::new(tenant_id, timeline_id, shard_map, auth_token.clone())?;
Ok(Self {
shards,
tenant_id,
timeline_id,
auth_token,
shards: ArcSwap::new(Arc::new(shards)),
retry: Retry,
})
}
/// Updates the shard map. In-flight requests will complete using the existing shard map, but
/// may retry with the new shard map if they fail.
///
/// TODO: make sure in-flight requests are allowed to complete, and that the old pools are
/// properly spun down and dropped afterwards.
pub fn update_shards(&self, shard_map: ShardMap) -> anyhow::Result<()> {
let shards = Shards::new(
self.tenant_id,
self.timeline_id,
shard_map,
self.auth_token.clone(),
)?;
self.shards.store(Arc::new(shards));
Ok(())
}
/// Returns whether a relation exists.
#[instrument(skip_all, fields(rel=%req.rel, lsn=%req.read_lsn))]
pub async fn check_rel_exists(
@@ -86,7 +112,7 @@ impl PageserverClient {
self.retry
.with(async || {
// Relation metadata is only available on shard 0.
let mut client = self.shards.get_zero().client().await?;
let mut client = self.shards.load().get_zero().client().await?;
client.check_rel_exists(req).await
})
.await
@@ -101,7 +127,7 @@ impl PageserverClient {
self.retry
.with(async || {
// Relation metadata is only available on shard 0.
let mut client = self.shards.get_zero().client().await?;
let mut client = self.shards.load().get_zero().client().await?;
client.get_db_size(req).await
})
.await
@@ -129,28 +155,67 @@ impl PageserverClient {
return Err(tonic::Status::invalid_argument("no block number"));
}
// The shard map may change while we're fetching pages. We execute the request with a stable
// view of the current shards, but if it fails and the shard map was changed concurrently,
// we retry with the new shard map. We have to do this in an outer retry loop because the
// shard map change may require us to resplit the request along different shard boundaries.
//
// TODO: do we need similary retry logic for other requests? Consider moving this into Retry
// somehow.
//
// TODO: we clone the request a bunch of places because of retries. We should pass a
// reference instead and clone at the leaves, but it requires some lifetime juggling.
loop {
let shards = self.shards.load_full();
match Self::get_page_with_shards(req.clone(), self.shards.load_full(), self.retry).await
{
Ok(resp) => return Ok(resp),
Err(status) => {
// If the shard map didn't change, just return the error.
if Arc::ptr_eq(&shards, &self.shards.load()) {
return Err(status);
}
// Otherwise, retry the request with the new shard map.
//
// TODO: we retry all errors here. Moved shards will typically return NotFound
// which is not normally retried. Consider only retrying NotFound here. This
// also needs to be coordinated with the server-side shard split logic.
warn!(
"shard map changed, retrying GetPage error {}: {}",
status.code(),
status.message()
);
}
}
}
}
/// Fetches pages using the given shards. This uses a stable view of the shards, regardless of
/// any concurrent shard map updates.
async fn get_page_with_shards(
req: page_api::GetPageRequest,
shards: Arc<Shards>,
retry: Retry,
) -> tonic::Result<page_api::GetPageResponse> {
// Fast path: request is for a single shard.
if let Some(shard_id) =
GetPageSplitter::is_single_shard(&req, self.shards.count, self.shards.stripe_size)
GetPageSplitter::is_single_shard(&req, shards.count, shards.stripe_size)
{
return self.get_page_for_shard(shard_id, req).await;
return Self::get_page_with_shard(req, shards.get(shard_id)?, retry).await;
}
// Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and
// reassemble the responses.
//
// TODO: when we support shard map updates, we need to detect when it changes and re-split
// the request on errors.
let mut splitter = GetPageSplitter::split(req, self.shards.count, self.shards.stripe_size);
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size);
let mut shard_requests: FuturesUnordered<_> = splitter
.drain_requests()
.map(|(shard_id, shard_req)| {
// NB: each request will retry internally.
self.get_page_for_shard(shard_id, shard_req)
.map(move |result| result.map(|resp| (shard_id, resp)))
})
.collect();
let mut shard_requests = FuturesUnordered::new();
for (shard_id, shard_req) in splitter.drain_requests() {
// NB: each request will retry internally.
let future = Self::get_page_with_shard(shard_req, shards.get(shard_id)?, retry)
.map(move |result| result.map(|resp| (shard_id, resp)));
shard_requests.push(future);
}
while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? {
splitter.add_response(shard_id, shard_response)?;
@@ -159,21 +224,16 @@ impl PageserverClient {
splitter.assemble_response()
}
/// Fetches pages that belong to the given shard.
#[instrument(skip_all, fields(shard = %shard_id))]
async fn get_page_for_shard(
&self,
shard_id: ShardIndex,
/// Fetches pages on the given shard.
#[instrument(skip_all, fields(shard = %shard.id))]
async fn get_page_with_shard(
req: page_api::GetPageRequest,
shard: &Shard,
retry: Retry,
) -> tonic::Result<page_api::GetPageResponse> {
let resp = self
.retry
let resp = retry
.with(async || {
let stream = self
.shards
.get(shard_id)?
.stream(req.request_class.is_bulk())
.await;
let stream = shard.stream(req.request_class.is_bulk()).await;
let resp = stream.send(req.clone()).await?;
// Convert per-request errors into a tonic::Status.
@@ -193,7 +253,8 @@ impl PageserverClient {
let (expected, actual) = (req.block_numbers.len(), resp.page_images.len());
if expected != actual {
return Err(tonic::Status::internal(format!(
"expected {expected} pages for shard {shard_id}, got {actual}",
"expected {expected} pages for shard {}, got {actual}",
shard.id,
)));
}
@@ -209,7 +270,7 @@ impl PageserverClient {
self.retry
.with(async || {
// Relation metadata is only available on shard 0.
let mut client = self.shards.get_zero().client().await?;
let mut client = self.shards.load().get_zero().client().await?;
client.get_rel_size(req).await
})
.await
@@ -224,48 +285,51 @@ impl PageserverClient {
self.retry
.with(async || {
// SLRU segments are only available on shard 0.
let mut client = self.shards.get_zero().client().await?;
let mut client = self.shards.load().get_zero().client().await?;
client.get_slru_segment(req).await
})
.await
}
}
/// Tracks the tenant's shards.
struct Shards {
/// Shard specification for a PageserverClient.
pub struct ShardMap {
/// Maps shard indices to gRPC URLs.
///
/// INVARIANT: every shard 0..count is present, and shard 0 is always present.
/// INVARIANT: every URL is valid and uses grpc:// scheme.
urls: HashMap<ShardIndex, String>,
/// The shard count.
///
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size. Only used for sharded tenants.
/// The stripe size for this shard map.
stripe_size: ShardStripeSize,
/// Shards by shard index.
///
/// NB: unsharded tenants use count 0, like `ShardIndex::unsharded()`.
///
/// INVARIANT: every shard 0..count is present.
/// INVARIANT: shard 0 is always present.
map: HashMap<ShardIndex, Shard>,
}
impl Shards {
/// Creates a new set of shards based on a shard map.
fn new(
tenant_id: TenantId,
timeline_id: TimelineId,
shard_map: HashMap<ShardIndex, String>,
stripe_size: ShardStripeSize,
auth_token: Option<String>,
impl ShardMap {
/// Creates a new shard map with the given URLs and stripe size. All shards must be given.
/// The stripe size may be omitted for unsharded tenants.
pub fn new(
urls: HashMap<ShardIndex, String>,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Self> {
let count = match shard_map.len() {
// Compute the shard count.
let count = match urls.len() {
0 => return Err(anyhow!("no shards provided")),
1 => ShardCount::new(0), // NB: unsharded tenants use 0, like `ShardIndex::unsharded()`
n if n > u8::MAX as usize => return Err(anyhow!("too many shards: {n}")),
n => ShardCount::new(n as u8),
};
let mut map = HashMap::new();
for (shard_id, url) in shard_map {
// Determine the stripe size. It doesn't matter for unsharded tenants.
if stripe_size.is_none() && !count.is_unsharded() {
return Err(anyhow!("stripe size must be given for sharded tenants"));
}
let stripe_size = stripe_size.unwrap_or_default();
// Validate the shard map.
for (shard_id, url) in &urls {
// The shard index must match the computed shard count, even for unsharded tenants.
if shard_id.shard_count != count {
return Err(anyhow!("invalid shard index {shard_id}, expected {count}"));
@@ -276,21 +340,64 @@ impl Shards {
}
// The above conditions guarantee that we have all shards 0..count: len() matches count,
// shard number < count, and numbers are unique (via hashmap).
let shard = Shard::new(url, tenant_id, timeline_id, shard_id, auth_token.clone())?;
map.insert(shard_id, shard);
// Validate the URL.
if PageserverProtocol::from_connstring(url)? != PageserverProtocol::Grpc {
return Err(anyhow!("invalid shard URL {url}: must use gRPC"));
}
}
Ok(Self {
urls,
count,
stripe_size,
map,
})
}
}
/// Tracks the tenant's shards.
struct Shards {
/// Shards by shard index.
///
/// INVARIANT: every shard 0..count is present.
/// INVARIANT: shard 0 is always present.
by_index: HashMap<ShardIndex, Shard>,
/// The shard count.
///
/// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
count: ShardCount,
/// The stripe size. Only used for sharded tenants.
stripe_size: ShardStripeSize,
}
impl Shards {
/// Creates a new set of shards based on a shard map.
fn new(
tenant_id: TenantId,
timeline_id: TimelineId,
shard_map: ShardMap,
auth_token: Option<String>,
) -> anyhow::Result<Self> {
// NB: the shard map has already been validated when constructed.
let mut shards = HashMap::with_capacity(shard_map.urls.len());
for (shard_id, url) in shard_map.urls {
shards.insert(
shard_id,
Shard::new(url, tenant_id, timeline_id, shard_id, auth_token.clone())?,
);
}
Ok(Self {
count: shard_map.count,
stripe_size: shard_map.stripe_size,
by_index: shards,
})
}
/// Looks up the given shard.
#[allow(clippy::result_large_err)] // TODO: check perf impact
fn get(&self, shard_id: ShardIndex) -> tonic::Result<&Shard> {
self.map
self.by_index
.get(&shard_id)
.ok_or_else(|| tonic::Status::not_found(format!("unknown shard {shard_id}")))
}
@@ -312,6 +419,8 @@ impl Shards {
/// * Bulk client pool: unbounded.
/// * Bulk stream pool: MAX_BULK_STREAMS and MAX_BULK_STREAM_QUEUE_DEPTH.
struct Shard {
/// The shard ID.
id: ShardIndex,
/// Unary gRPC client pool.
client_pool: Arc<ClientPool>,
/// GetPage stream pool.
@@ -329,11 +438,6 @@ impl Shard {
shard_id: ShardIndex,
auth_token: Option<String>,
) -> anyhow::Result<Self> {
// Sanity-check that the URL uses gRPC.
if PageserverProtocol::from_connstring(&url)? != PageserverProtocol::Grpc {
return Err(anyhow!("invalid shard URL {url}: must use gRPC"));
}
// Common channel pool for unary and stream requests. Bounded by client/stream pools.
let channel_pool = ChannelPool::new(url.clone(), MAX_CLIENTS_PER_CHANNEL)?;
@@ -378,6 +482,7 @@ impl Shard {
);
Ok(Self {
id: shard_id,
client_pool,
stream_pool,
bulk_stream_pool,

View File

@@ -3,4 +3,4 @@ mod pool;
mod retry;
mod split;
pub use client::PageserverClient;
pub use client::{PageserverClient, ShardMap};

View File

@@ -8,6 +8,7 @@ use utils::backoff::exponential_backoff_duration;
/// A retry handler for Pageserver gRPC requests.
///
/// This is used instead of backoff::retry for better control and observability.
#[derive(Clone, Copy)]
pub struct Retry;
impl Retry {

View File

@@ -12,8 +12,7 @@ use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess};
use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest};
use crate::neon_request::{NeonIORequest, NeonIOResult};
use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable};
use pageserver_api::shard::ShardStripeSize;
use pageserver_client_grpc::PageserverClient;
use pageserver_client_grpc::{PageserverClient, ShardMap};
use pageserver_page_api as page_api;
use metrics::{IntCounter, IntCounterVec};
@@ -93,11 +92,11 @@ pub(super) async fn init(
.worker_process_init(last_lsn, file_cache);
// TODO: plumb through the stripe size.
let stripe_size = ShardStripeSize::default();
let tenant_id = TenantId::from_str(&tenant_id).expect("invalid tenant ID");
let timeline_id = TimelineId::from_str(&timeline_id).expect("invalid timeline ID");
let client = PageserverClient::new(tenant_id, timeline_id, shard_map, stripe_size, auth_token)
.expect("count not create client");
let shard_map = ShardMap::new(shard_map, None).expect("invalid shard map");
let client = PageserverClient::new(tenant_id, timeline_id, shard_map, auth_token)
.expect("could not create client");
let request_counters = IntCounterVec::new(
metrics::core::Opts::new(