Add neon-shard-id header

This commit is contained in:
Erik Grinaker
2025-04-30 11:18:06 +02:00
parent 7bb58be546
commit 4c77397943
7 changed files with 89 additions and 44 deletions

1
Cargo.lock generated
View File

@@ -4533,6 +4533,7 @@ dependencies = [
"thiserror 1.0.69",
"tonic",
"tracing",
"utils",
]
[[package]]

View File

@@ -12,3 +12,4 @@ tonic.workspace = true
tracing.workspace = true
pageserver_page_api.workspace = true
utils.workspace = true

View File

@@ -15,9 +15,8 @@ use tonic::transport::Channel;
use pageserver_page_api::model::*;
use pageserver_page_api::proto;
type Shardno = u16;
use pageserver_page_api::proto::PageServiceClient;
use utils::shard::ShardIndex;
type MyPageServiceClient = pageserver_page_api::proto::PageServiceClient<
tonic::service::interceptor::InterceptedService<tonic::transport::Channel, AuthInterceptor>,
@@ -40,9 +39,9 @@ pub struct PageserverClient {
_auth_token: Option<String>,
shard_map: HashMap<Shardno, String>,
shard_map: HashMap<ShardIndex, String>,
channels: RwLock<HashMap<Shardno, Channel>>,
channels: RwLock<HashMap<ShardIndex, Channel>>,
auth_interceptor: AuthInterceptor,
}
@@ -53,7 +52,7 @@ impl PageserverClient {
tenant_id: &str,
timeline_id: &str,
auth_token: &Option<String>,
shard_map: HashMap<Shardno, String>,
shard_map: HashMap<ShardIndex, String>,
) -> Self {
Self {
_tenant_id: tenant_id.to_string(),
@@ -70,9 +69,9 @@ impl PageserverClient {
request: &RelExistsRequest,
) -> Result<bool, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard_no = 0;
let shard = ShardIndex::unsharded();
let mut client = self.get_client(shard_no).await?;
let mut client = self.get_client(shard).await?;
let request = proto::RelExistsRequest::from(request);
let response = client.rel_exists(tonic::Request::new(request)).await?;
@@ -85,9 +84,9 @@ impl PageserverClient {
request: &RelSizeRequest,
) -> Result<u32, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard_no = 0;
let shard = ShardIndex::unsharded();
let mut client = self.get_client(shard_no).await?;
let mut client = self.get_client(shard).await?;
let request = proto::RelSizeRequest::from(request);
let response = client.rel_size(tonic::Request::new(request)).await?;
@@ -97,9 +96,9 @@ impl PageserverClient {
pub async fn get_page(&self, request: &GetPageRequest) -> Result<Bytes, PageserverClientError> {
// FIXME: calculate the shard number correctly
let shard_no = 0;
let shard = ShardIndex::unsharded();
let mut client = self.get_client(shard_no).await?;
let mut client = self.get_client(shard).await?;
let request = proto::GetPageRequest::from(request);
let response = client.get_page(tonic::Request::new(request)).await?;
@@ -115,9 +114,9 @@ impl PageserverClient {
PageserverClientError,
> {
// FIXME: calculate the shard number correctly
let shard_no = 0;
let shard = ShardIndex::unsharded();
let mut client = self.get_client(shard_no).await?;
let mut client = self.get_client(shard).await?;
Ok(client.get_pages(tonic::Request::new(requests)).await?)
}
@@ -128,9 +127,9 @@ impl PageserverClient {
request: &DbSizeRequest,
) -> Result<u64, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard_no = 0;
let shard = ShardIndex::unsharded();
let mut client = self.get_client(shard_no).await?;
let mut client = self.get_client(shard).await?;
let request = proto::DbSizeRequest::from(request);
let response = client.db_size(tonic::Request::new(request)).await?;
@@ -148,9 +147,9 @@ impl PageserverClient {
PageserverClientError,
> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard_no = 0;
let shard = ShardIndex::unsharded();
let mut client = self.get_client(shard_no).await?;
let mut client = self.get_client(shard).await?;
if gzip {
client = client.accept_compressed(tonic::codec::CompressionEncoding::Gzip);
}
@@ -167,12 +166,12 @@ impl PageserverClient {
/// reuse it. If not, create a new client and put it to the cache.
async fn get_client(
&self,
shard_no: u16,
shard: ShardIndex,
) -> Result<MyPageServiceClient, PageserverClientError> {
let reused_channel: Option<Channel> = {
let channels = self.channels.read().unwrap();
channels.get(&shard_no).cloned()
channels.get(&shard).cloned()
};
let channel = if let Some(reused_channel) = reused_channel {
@@ -180,8 +179,8 @@ impl PageserverClient {
} else {
let endpoint: tonic::transport::Endpoint = self
.shard_map
.get(&shard_no)
.expect("no url for shard {shard_no}")
.get(&shard)
.expect("no url for shard {shard}")
.parse()?;
let channel = endpoint.connect().await?;
@@ -190,12 +189,13 @@ impl PageserverClient {
// client in the cache.
{
let mut channels = self.channels.write().unwrap();
channels.insert(shard_no, channel.clone());
channels.insert(shard, channel.clone());
}
channel
};
let client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.clone());
let client =
PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(shard));
Ok(client)
}
}
@@ -204,6 +204,7 @@ impl PageserverClient {
#[derive(Clone)]
struct AuthInterceptor {
tenant_id: AsciiMetadataValue,
shard_id: Option<AsciiMetadataValue>,
timeline_id: AsciiMetadataValue,
auth_header: Option<AsciiMetadataValue>, // including "Bearer " prefix
@@ -213,12 +214,24 @@ impl AuthInterceptor {
/// Builds an interceptor carrying the tenant and timeline IDs and, when a token is
/// supplied, a ready-made `Bearer ` authorization value. No shard ID is set yet.
///
/// # Panics
///
/// Panics if the tenant ID, the timeline ID, or the `Bearer`-prefixed token cannot be
/// parsed into a metadata value.
fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self {
    let tenant_id = tenant_id.parse().expect("could not parse tenant id");
    let timeline_id = timeline_id.parse().expect("could not parse timeline id");
    // Prefix and parse in one step so the stored value can be attached as-is.
    let auth_header = auth_token.map(|token| {
        format!("Bearer {token}")
            .parse()
            .expect("could not parse auth token")
    });
    Self {
        tenant_id,
        shard_id: None,
        timeline_id,
        auth_header,
    }
}
fn for_shard(&self, shard_id: ShardIndex) -> Self {
let mut with_shard = self.clone();
with_shard.shard_id = Some(
shard_id
.to_string()
.parse()
.expect("could not parse shard id"),
);
with_shard
}
}
impl tonic::service::Interceptor for AuthInterceptor {

View File

@@ -3,6 +3,7 @@
// Request metadata:
// - authorization: JWT token ("Bearer <token>"), if auth is enabled
// - neon-tenant-id: tenant ID ("7c4a1f9e3bd6470c8f3e21a65bd2e980")
// - neon-shard-id: shard ID, as <number><count> in hex ("0b10" = shard 11 of 16)
// - neon-timeline-id: timeline ID ("f08c4e9a2d5f76b1e3a7c2d8910f4b3e")
//
// TODO: what else? Priority? OpenTelemetry tracing?

View File

@@ -17,6 +17,7 @@ use tokio::task::JoinSet;
use tracing::{info, instrument};
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
@@ -297,7 +298,10 @@ async fn client_grpc(
all_work_done_barrier: Arc<Barrier>,
live_stats: Arc<LiveStats>,
) {
let shard_map = HashMap::from([(0, args.page_service_connstring.clone())]);
let shard_map = HashMap::from([(
ShardIndex::unsharded(),
args.page_service_connstring.clone(),
)]);
let client = pageserver_client_grpc::PageserverClient::new(
&timeline.tenant_id.to_string(),
&timeline.timeline_id.to_string(),

View File

@@ -20,6 +20,7 @@ use tokio_util::sync::CancellationToken;
use tracing::info;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
@@ -457,7 +458,10 @@ async fn client_grpc(
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
let shard_map = HashMap::from([(0, args.page_service_connstring.clone())]);
let shard_map = HashMap::from([(
ShardIndex::unsharded(),
args.page_service_connstring.clone(),
)]);
let client = pageserver_client_grpc::PageserverClient::new(
&worker_id.timeline.tenant_id.to_string(),
&worker_id.timeline.timeline_id.to_string(),
@@ -554,7 +558,10 @@ async fn client_grpc_stream(
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
let shard_map = HashMap::from([(0, args.page_service_connstring.clone())]);
let shard_map = HashMap::from([(
ShardIndex::unsharded(),
args.page_service_connstring.clone(),
)]);
let client = pageserver_client_grpc::PageserverClient::new(
&worker_id.timeline.tenant_id.to_string(),
&worker_id.timeline.timeline_id.to_string(),

View File

@@ -65,8 +65,6 @@ use tonic;
use tonic::codec::CompressionEncoding;
use tonic::service::interceptor::InterceptedService;
use pageserver_api::key::rel_block_to_key;
use crate::pgdatadir_mapping::Version;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
@@ -156,13 +154,14 @@ impl PageService for PageServiceService {
request: tonic::Request<proto::RelExistsRequest>,
) -> std::result::Result<tonic::Response<proto::RelExistsResponse>, tonic::Status> {
let ttid = self.extract_ttid(request.metadata())?;
let shard = self.extract_shard(request.metadata())?;
let req: model::RelExistsRequest = request.get_ref().try_into()?;
let rel = convert_reltag(&req.rel);
let span = tracing::info_span!("rel_exists", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, rel = %rel, req_lsn = %req.common.request_lsn);
async {
let timeline = self.get_timeline(ttid, ShardSelector::Zero).await?;
let timeline = self.get_timeline(ttid, shard).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(
@@ -190,13 +189,14 @@ impl PageService for PageServiceService {
request: tonic::Request<proto::RelSizeRequest>,
) -> std::result::Result<tonic::Response<proto::RelSizeResponse>, tonic::Status> {
let ttid = self.extract_ttid(request.metadata())?;
let shard = self.extract_shard(request.metadata())?;
let req: model::RelSizeRequest = request.get_ref().try_into()?;
let rel = convert_reltag(&req.rel);
let span = tracing::info_span!("rel_size", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, rel = %rel, req_lsn = %req.common.request_lsn);
async {
let timeline = self.get_timeline(ttid, ShardSelector::Zero).await?;
let timeline = self.get_timeline(ttid, shard).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(
@@ -221,14 +221,11 @@ impl PageService for PageServiceService {
request: tonic::Request<proto::GetPageRequest>,
) -> std::result::Result<tonic::Response<proto::GetPageResponse>, tonic::Status> {
let ttid = self.extract_ttid(request.metadata())?;
let shard = self.extract_shard(request.metadata())?;
let req: model::GetPageRequest = request.get_ref().try_into()?;
// Calculate shard number.
//
// FIXME: this should probably be part of the data_api crate.
let rel = convert_reltag(&req.rel);
let key = rel_block_to_key(rel, req.block_number);
let timeline = self.get_timeline(ttid, ShardSelector::Page(key)).await?;
let timeline = self.get_timeline(ttid, shard).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
@@ -274,11 +271,9 @@ impl PageService for PageServiceService {
&self,
request: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
// TODO: pass the shard index in the request metadata.
let ttid = self.extract_ttid(request.metadata())?;
let timeline = self
.get_timeline(ttid, ShardSelector::Known(ShardIndex::unsharded()))
.await?;
let shard = self.extract_shard(request.metadata())?;
let timeline = self.get_timeline(ttid, shard).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
let conf = self.conf;
@@ -327,12 +322,13 @@ impl PageService for PageServiceService {
request: tonic::Request<proto::DbSizeRequest>,
) -> Result<tonic::Response<proto::DbSizeResponse>, tonic::Status> {
let ttid = self.extract_ttid(request.metadata())?;
let shard = self.extract_shard(request.metadata())?;
let req: model::DbSizeRequest = request.get_ref().try_into()?;
let span = tracing::info_span!("get_page", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, db_oid = %req.db_oid, req_lsn = %req.common.request_lsn);
async {
let timeline = self.get_timeline(ttid, ShardSelector::Zero).await?;
let timeline = self.get_timeline(ttid, shard).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(
@@ -361,9 +357,10 @@ impl PageService for PageServiceService {
request: tonic::Request<proto::GetBaseBackupRequest>,
) -> Result<tonic::Response<Self::GetBaseBackupStream>, tonic::Status> {
let ttid = self.extract_ttid(request.metadata())?;
let shard = self.extract_shard(request.metadata())?;
let req: model::GetBaseBackupRequest = request.get_ref().try_into()?;
let timeline = self.get_timeline(ttid, ShardSelector::Zero).await?;
let timeline = self.get_timeline(ttid, shard).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
@@ -500,12 +497,13 @@ impl PageService for PageServiceService {
request: tonic::Request<proto::GetSlruSegmentRequest>,
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
let ttid = self.extract_ttid(request.metadata())?;
let shard = self.extract_shard(request.metadata())?;
let req: model::GetSlruSegmentRequest = request.get_ref().try_into()?;
let span = tracing::info_span!("get_slru_segment", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, kind = %req.kind, segno = %req.segno, req_lsn = %req.common.request_lsn);
async {
let timeline = self.get_timeline(ttid, ShardSelector::Zero).await?;
let timeline = self.get_timeline(ttid, shard).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(
@@ -540,7 +538,7 @@ impl PageServiceService {
async fn get_timeline(
&self,
ttid: TenantTimelineId,
shard_selector: ShardSelector,
shard: ShardIndex,
) -> Result<Arc<Timeline>, tonic::Status> {
let timeout = ACTIVE_TENANT_TIMEOUT;
let wait_start = Instant::now();
@@ -549,7 +547,7 @@ impl PageServiceService {
let tenant_shard = loop {
let resolved = self
.tenant_mgr
.resolve_attached_shard(&ttid.tenant_id, shard_selector);
.resolve_attached_shard(&ttid.tenant_id, ShardSelector::Known(shard));
match resolved {
ShardResolveResult::Found(tenant_shard) => break tenant_shard,
@@ -623,6 +621,26 @@ impl PageServiceService {
Ok(TenantTimelineId::new(tenant_id, timeline_id))
}
/// Extracts the [`ShardIndex`] from the `neon-shard-id` request metadata.
///
/// # Errors
///
/// Returns `tonic::Status::invalid_argument` if the `neon-shard-id` entry is missing,
/// if its value cannot be read as a string, or if it does not parse as a
/// [`ShardIndex`].
fn extract_shard(
    &self,
    metadata: &tonic::metadata::MetadataMap,
) -> Result<ShardIndex, tonic::Status> {
    let shard_id = metadata
        .get("neon-shard-id")
        // Build the status lazily; it is only needed when the header is absent.
        .ok_or_else(|| tonic::Status::invalid_argument("neon-shard-id metadata missing"))?
        .to_str()
        // `to_str` rejects any value that is not made of visible ASCII characters.
        .map_err(|_| {
            tonic::Status::invalid_argument("invalid characters in neon-shard-id metadata")
        })?;
    ShardIndex::from_str(shard_id)
        .map_err(|err| tonic::Status::invalid_argument(format!("invalid neon-shard-id: {err}")))
}
// XXX: copied from PageServerHandler
async fn wait_or_get_last_lsn(
timeline: &Timeline,