Add GetPage request splitting

This commit is contained in:
Erik Grinaker
2025-07-03 18:31:06 +02:00
parent 96a817fa2b
commit 42e4e5a418
3 changed files with 239 additions and 32 deletions

View File

@@ -2,13 +2,15 @@ use std::collections::HashMap;
use std::sync::Arc;
use anyhow::anyhow;
use futures::stream::FuturesUnordered;
use futures::{FutureExt as _, StreamExt};
use tokio::time::Instant;
use tracing::{error, info, instrument, warn};
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::split::GetPageSplitter;
use compute_api::spec::PageserverProtocol;
use pageserver_api::key::{Key, rel_block_to_key};
use pageserver_api::shard::{ShardStripeSize, key_to_shard_number};
use pageserver_api::shard::ShardStripeSize;
use pageserver_page_api as page_api;
use utils::backoff::exponential_backoff_duration;
use utils::id::{TenantId, TimelineId};
@@ -73,7 +75,8 @@ impl PageserverClient {
.await
}
/// Fetches a page. The `request_id` must be unique across all in-flight requests.
/// Fetches pages. The `request_id` must be unique across all in-flight requests. Will
/// automatically split requests that span multiple shards, and reassemble the responses.
///
/// Unlike the `page_api::Client`, this client automatically converts `status_code` into
/// `tonic::Status` errors. All responses will have `GetPageStatusCode::Ok`.
@@ -88,31 +91,75 @@ impl PageserverClient {
&self,
req: page_api::GetPageRequest,
) -> tonic::Result<page_api::GetPageResponse> {
// TODO: this needs to split batch requests across shards and reassemble responses into a
// single response. It must also re-split the batch in case the shard map changes. For now,
// just use the first page.
let key = rel_block_to_key(
req.rel,
req.block_numbers
.first()
.copied()
.ok_or_else(|| tonic::Status::invalid_argument("no block numbers provided"))?,
);
// Make sure we have at least one page.
if req.block_numbers.is_empty() {
return Err(tonic::Status::invalid_argument("no block number"));
}
self.with_retries(async || {
let stream = self.shards.get_for_key(key).stream().await;
let resp = stream.send(req.clone()).await?;
// Fast path: request is for a single shard.
if let Some(shard_id) =
GetPageSplitter::is_single_shard(&req, self.shards.count, self.shards.stripe_size)
{
return self.get_page_for_shard(shard_id, req).await;
}
if resp.status_code != page_api::GetPageStatusCode::Ok {
return Err(tonic::Status::new(
resp.status_code.into(),
resp.reason.unwrap_or_else(|| String::from("unknown error")),
));
}
// Slow path: request spans multiple shards. Split it, dispatch per-shard requests in
// parallel, and reassemble the responses.
//
// TODO: when we add shard map updates, we need to detect that case and re-split the
// request on errors.
let mut splitter = GetPageSplitter::split(req, self.shards.count, self.shards.stripe_size);
Ok(resp)
})
.await
let mut shard_requests: FuturesUnordered<_> = splitter
.drain_requests()
.map(|(shard_id, shard_req)| {
// NB: each request will retry internally.
self.get_page_for_shard(shard_id, shard_req)
.map(move |result| result.map(|resp| (shard_id, resp)))
})
.collect();
while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? {
splitter.add_response(shard_id, shard_response)?;
}
splitter.reassemble()
}
/// Fetches pages that belong to the given shard.
#[instrument(skip_all, fields(shard = %shard_id))]
async fn get_page_for_shard(
&self,
shard_id: ShardIndex,
req: page_api::GetPageRequest,
) -> tonic::Result<page_api::GetPageResponse> {
let resp = self
.with_retries(async || {
let stream = self.shards.get(shard_id)?.stream().await;
let resp = stream.send(req.clone()).await?;
// Convert per-request errors into a tonic::Status.
if resp.status_code != page_api::GetPageStatusCode::Ok {
return Err(tonic::Status::new(
resp.status_code.into(),
resp.reason.unwrap_or_else(|| String::from("unknown error")),
));
}
Ok(resp)
})
.await?;
// Make sure we got the right number of pages.
// NB: check outside of the retry loop, since we don't want to retry this.
let (expected, actual) = (req.block_numbers.len(), resp.page_images.len());
if expected != actual {
return Err(tonic::Status::internal(format!(
"expected {expected} pages for shard {shard_id}, got {actual}",
)));
}
Ok(resp)
}
/// Returns the size of a relation, as # of blocks.
@@ -319,13 +366,6 @@ impl Shards {
.ok_or_else(|| tonic::Status::not_found(format!("unknown shard {shard_id}")))
}
/// Looks up the shard that owns the given key.
fn get_for_key(&self, key: Key) -> &Shard {
let shard_number = key_to_shard_number(self.count, self.stripe_size, &key);
self.get(ShardIndex::new(shard_number, self.count))
.expect("must exist")
}
/// Returns shard 0.
fn get_zero(&self) -> &Shard {
self.get(ShardIndex::new(ShardNumber(0), self.count))

View File

@@ -1,4 +1,5 @@
mod client;
mod pool;
mod split;
pub use client::PageserverClient;

View File

@@ -0,0 +1,166 @@
use std::collections::HashMap;
use bytes::Bytes;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::shard::{ShardStripeSize, key_to_shard_number};
use pageserver_page_api as page_api;
use utils::shard::{ShardCount, ShardIndex};
/// Splits GetPageRequests across shard boundaries and reassembles the responses.
/// TODO: add tests for this.
pub struct GetPageSplitter {
/// The original request ID. Used for all shard requests.
request_id: page_api::RequestID,
/// Requests by shard index.
requests: HashMap<ShardIndex, page_api::GetPageRequest>,
/// Maps the page offset in the input request (index) to the shard index. This is used to
/// reassemble the responses in the same order as the original request.
block_shards: Vec<ShardIndex>,
/// Page responses by shard index. Will be reassembled into a single response.
responses: HashMap<ShardIndex, Vec<Bytes>>,
}
impl GetPageSplitter {
/// Checks if the given request belongs to a single shard, and returns the shard ID. This is the
/// common case, so we do a full scan in order to avoid unnecessary allocations and overhead.
/// The caller must ensure that the request has at least one block number, or this will panic.
pub fn is_single_shard(
req: &page_api::GetPageRequest,
count: ShardCount,
stripe_size: ShardStripeSize,
) -> Option<ShardIndex> {
// Fast path: unsharded tenant.
if count.is_unsharded() {
return Some(ShardIndex::unsharded());
}
// Find the base shard index for the first page, and compare with the rest.
let key = rel_block_to_key(req.rel, *req.block_numbers.first().expect("no pages"));
let shard_number = key_to_shard_number(count, stripe_size, &key);
req.block_numbers
.iter()
.skip(1) // computed above
.all(|&blkno| {
let key = rel_block_to_key(req.rel, blkno);
key_to_shard_number(count, stripe_size, &key) == shard_number
})
.then_some(ShardIndex::new(shard_number, count))
}
/// Splits the given request.
pub fn split(
req: page_api::GetPageRequest,
count: ShardCount,
stripe_size: ShardStripeSize,
) -> Self {
// The caller should make sure we don't split requests unnecessarily.
debug_assert!(
Self::is_single_shard(&req, count, stripe_size).is_some(),
"unnecessary request split"
);
// Split the requests by shard index.
let mut requests = HashMap::with_capacity(2); // common case
let mut block_shards = Vec::with_capacity(req.block_numbers.len());
for blkno in req.block_numbers {
let key = rel_block_to_key(req.rel, blkno);
let shard_number = key_to_shard_number(count, stripe_size, &key);
let shard_id = ShardIndex::new(shard_number, count);
let shard_req = requests
.entry(shard_id)
.or_insert_with(|| page_api::GetPageRequest {
request_id: req.request_id,
request_class: req.request_class,
rel: req.rel,
read_lsn: req.read_lsn,
block_numbers: Vec::new(),
});
shard_req.block_numbers.push(blkno);
block_shards.push(shard_id);
}
Self {
request_id: req.request_id,
responses: HashMap::with_capacity(requests.len()),
requests,
block_shards,
}
}
/// Drains the per-shard requests, moving them out of the hashmap to avoid extra allocations.
pub fn drain_requests(
&mut self,
) -> impl Iterator<Item = (ShardIndex, page_api::GetPageRequest)> {
self.requests.drain()
}
/// Adds a response for the given shard index.
#[allow(clippy::result_large_err)]
pub fn add_response(
&mut self,
shard_id: ShardIndex,
response: page_api::GetPageResponse,
) -> tonic::Result<()> {
// The caller should already have converted status codes into tonic::Status.
assert_eq!(response.status_code, page_api::GetPageStatusCode::Ok);
// Ensure the response is for the same request ID.
if response.request_id != self.request_id {
return Err(tonic::Status::internal(format!(
"response ID {} does not match request ID {}",
response.request_id, self.request_id
)));
}
// Add the response data to the map.
let old = self.responses.insert(shard_id, response.page_images);
assert!(old.is_none(), "duplicate response for shard {shard_id}");
Ok(())
}
/// Reassembles the shard responses into a single response.
#[allow(clippy::result_large_err)]
pub fn reassemble(self) -> tonic::Result<page_api::GetPageResponse> {
let mut response = page_api::GetPageResponse {
request_id: self.request_id,
status_code: page_api::GetPageStatusCode::Ok,
reason: None,
page_images: Vec::with_capacity(self.block_shards.len()),
};
// Convert the shard responses into iterators we can conveniently pull from.
let mut shard_responses = HashMap::with_capacity(self.responses.len());
for (shard_id, responses) in self.responses {
shard_responses.insert(shard_id, responses.into_iter());
}
// Reassemble the responses in the same order as the original request.
for shard_id in &self.block_shards {
let page = shard_responses
.get_mut(shard_id)
.ok_or_else(|| {
tonic::Status::internal(format!("missing response for shard {shard_id}"))
})?
.next()
.ok_or_else(|| {
tonic::Status::internal(format!("missing page from shard {shard_id}"))
})?;
response.page_images.push(page);
}
// Make sure we didn't get any additional pages.
for (shard_id, mut pages) in shard_responses {
if pages.next().is_some() {
return Err(tonic::Status::internal(format!(
"extra pages returned from shard {shard_id}"
)));
}
}
Ok(response)
}
}