libs: refactor ShardCount.0 to private (#6690)

## Problem

The ShardCount type has a magic '0' value that represents a legacy
single-sharded tenant, whose TenantShardId is formatted without a
`-0001` suffix (i.e. formatted as a traditional TenantId).

This was error-prone in code locations that wanted the actual number of
shards: they had to handle the 0 case specially.
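
To make the failure mode concrete, here is a sketch (not code from this tree) of the pattern this forced on callers:

```rust
use pageserver_api::shard::ShardCount;

// Hypothetical pre-refactor helper: with the inner value public, every caller
// that wanted the real shard count had to remember the magic zero.
fn actual_shard_count(count: ShardCount) -> u8 {
    std::cmp::max(count.0, 1) // forget the max() and unsharded tenants break
}
```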

## Summary of changes

- Make the internal value of ShardCount private, and expose `count()`
and `literal()` getters so that callers have to explicitly say whether
they want the literal value (e.g. for storing in a TenantShardId), or
the actual number of shards in the tenant.
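
An illustrative sketch of how the two getters (plus the `is_unsharded()` helper added in the same change) read at call sites:

```rust
use pageserver_api::shard::ShardCount;

let legacy = ShardCount::new(0); // legacy unsharded tenant
assert!(legacy.is_unsharded());
assert_eq!(legacy.literal(), 0); // raw value, e.g. for TenantShardId or persistence
assert_eq!(legacy.count(), 1); // the tenant really has one shard

let sharded = ShardCount::new(4);
assert_eq!(sharded.literal(), 4); // literal and count agree once actually sharded
assert_eq!(sharded.count(), 4);
```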


---------

Co-authored-by: Arpad Müller <arpad-m@users.noreply.github.com>
Author: John Spray
Date: 2024-02-15 21:59:39 +00:00
Committed by: GitHub
Commit: 6b980f38da (parent: f0d8bd7855)
13 changed files with 75 additions and 54 deletions

View File

@@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
 use control_plane::endpoint::{ComputeControlPlane, EndpointStatus};
 use control_plane::local_env::LocalEnv;
 use hyper::{Method, StatusCode};
-use pageserver_api::shard::{ShardCount, ShardIndex, ShardNumber, TenantShardId};
+use pageserver_api::shard::{ShardIndex, ShardNumber, TenantShardId};
 use postgres_connection::parse_host_port;
 use serde::{Deserialize, Serialize};
 use tokio_util::sync::CancellationToken;

@@ -77,7 +77,7 @@ impl ComputeHookTenant {
         self.shards
             .sort_by_key(|(shard, _node_id)| shard.shard_number);

-        if self.shards.len() == shard_count.0 as usize || shard_count == ShardCount(0) {
+        if self.shards.len() == shard_count.count() as usize || shard_count.is_unsharded() {
             // We have pageservers for all the shards: emit a configuration update
             return Some(ComputeHookNotifyRequest {
                 tenant_id,

@@ -94,7 +94,7 @@ impl ComputeHookTenant {
             tracing::info!(
                 "ComputeHookTenant::maybe_reconfigure: not enough shards ({}/{})",
                 self.shards.len(),
-                shard_count.0
+                shard_count.count()
             );
         }

View File

@@ -222,7 +222,7 @@ impl Persistence {
             let tenant_shard_id = TenantShardId {
                 tenant_id: TenantId::from_str(tsp.tenant_id.as_str())?,
                 shard_number: ShardNumber(tsp.shard_number as u8),
-                shard_count: ShardCount(tsp.shard_count as u8),
+                shard_count: ShardCount::new(tsp.shard_count as u8),
             };

             tenants_map.insert(tenant_shard_id, tsp);

@@ -318,7 +318,7 @@ impl Persistence {
                 tenant_id: TenantId::from_str(tsp.tenant_id.as_str())
                     .map_err(|e| DatabaseError::Logical(format!("Malformed tenant id: {e}")))?,
                 shard_number: ShardNumber(tsp.shard_number as u8),
-                shard_count: ShardCount(tsp.shard_count as u8),
+                shard_count: ShardCount::new(tsp.shard_count as u8),
             };
             result.insert(tenant_shard_id, Generation::new(tsp.generation as u32));
         }

@@ -340,7 +340,7 @@ impl Persistence {
         let updated = diesel::update(tenant_shards)
             .filter(tenant_id.eq(tenant_shard_id.tenant_id.to_string()))
             .filter(shard_number.eq(tenant_shard_id.shard_number.0 as i32))
-            .filter(shard_count.eq(tenant_shard_id.shard_count.0 as i32))
+            .filter(shard_count.eq(tenant_shard_id.shard_count.literal() as i32))
             .set((
                 generation.eq(generation + 1),
                 generation_pageserver.eq(node_id.0 as i64),

@@ -362,7 +362,7 @@ impl Persistence {
         let updated = diesel::update(tenant_shards)
             .filter(tenant_id.eq(tenant_shard_id.tenant_id.to_string()))
             .filter(shard_number.eq(tenant_shard_id.shard_number.0 as i32))
-            .filter(shard_count.eq(tenant_shard_id.shard_count.0 as i32))
+            .filter(shard_count.eq(tenant_shard_id.shard_count.literal() as i32))
             .set((
                 generation_pageserver.eq(i64::MAX),
                 placement_policy.eq(serde_json::to_string(&PlacementPolicy::Detached).unwrap()),

@@ -392,21 +392,19 @@ impl Persistence {
         conn.transaction(|conn| -> DatabaseResult<()> {
             // Mark parent shards as splitting
-            let expect_parent_records = std::cmp::max(1, old_shard_count.0);
-
             let updated = diesel::update(tenant_shards)
                 .filter(tenant_id.eq(split_tenant_id.to_string()))
-                .filter(shard_count.eq(old_shard_count.0 as i32))
+                .filter(shard_count.eq(old_shard_count.literal() as i32))
                 .set((splitting.eq(1),))
                 .execute(conn)?;
             if u8::try_from(updated)
                 .map_err(|_| DatabaseError::Logical(
                     format!("Overflow existing shard count {} while splitting", updated))
-                )? != expect_parent_records {
+                )? != old_shard_count.count() {
                 // Perhaps a deletion or another split raced with this attempt to split, mutating
                 // the parent shards that we intend to split. In this case the split request should fail.
                 return Err(DatabaseError::Logical(
-                    format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {expect_parent_records})")
+                    format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {})", old_shard_count.count())
                 ));
             }

@@ -418,7 +416,7 @@ impl Persistence {
                 let mut parent = crate::schema::tenant_shards::table
                     .filter(tenant_id.eq(parent_shard_id.tenant_id.to_string()))
                     .filter(shard_number.eq(parent_shard_id.shard_number.0 as i32))
-                    .filter(shard_count.eq(parent_shard_id.shard_count.0 as i32))
+                    .filter(shard_count.eq(parent_shard_id.shard_count.literal() as i32))
                     .load::<TenantShardPersistence>(conn)?;
                 let parent = if parent.len() != 1 {
                     return Err(DatabaseError::Logical(format!(

@@ -459,7 +457,7 @@ impl Persistence {
             // Drop parent shards
             diesel::delete(tenant_shards)
                 .filter(tenant_id.eq(split_tenant_id.to_string()))
-                .filter(shard_count.eq(old_shard_count.0 as i32))
+                .filter(shard_count.eq(old_shard_count.literal() as i32))
                 .execute(conn)?;

             // Clear sharding flag

View File

@@ -263,7 +263,7 @@ impl Reconciler {
             secondary_conf,
             tenant_conf: config.clone(),
             shard_number: shard.number.0,
-            shard_count: shard.count.0,
+            shard_count: shard.count.literal(),
             shard_stripe_size: shard.stripe_size.0,
         }
     }

@@ -458,7 +458,7 @@ impl Reconciler {
                 generation: None,
                 secondary_conf: None,
                 shard_number: self.shard.number.0,
-                shard_count: self.shard.count.0,
+                shard_count: self.shard.count.literal(),
                 shard_stripe_size: self.shard.stripe_size.0,
                 tenant_conf: self.config.clone(),
             },

@@ -506,7 +506,7 @@ pub(crate) fn attached_location_conf(
         generation: generation.into(),
         secondary_conf: None,
         shard_number: shard.number.0,
-        shard_count: shard.count.0,
+        shard_count: shard.count.literal(),
         shard_stripe_size: shard.stripe_size.0,
         tenant_conf: config.clone(),
     }

@@ -521,7 +521,7 @@ pub(crate) fn secondary_location_conf(
        generation: None,
        secondary_conf: Some(LocationConfigSecondary { warm: true }),
        shard_number: shard.number.0,
-       shard_count: shard.count.0,
+       shard_count: shard.count.literal(),
        shard_stripe_size: shard.stripe_size.0,
        tenant_conf: config.clone(),
    }

View File

@@ -292,7 +292,7 @@ impl Service {
                     generation: None,
                     secondary_conf: None,
                     shard_number: tenant_shard_id.shard_number.0,
-                    shard_count: tenant_shard_id.shard_count.0,
+                    shard_count: tenant_shard_id.shard_count.literal(),
                     shard_stripe_size: 0,
                     tenant_conf: models::TenantConfig::default(),
                 },

@@ -389,14 +389,14 @@ impl Service {
             let tenant_shard_id = TenantShardId {
                 tenant_id: TenantId::from_str(tsp.tenant_id.as_str())?,
                 shard_number: ShardNumber(tsp.shard_number as u8),
-                shard_count: ShardCount(tsp.shard_count as u8),
+                shard_count: ShardCount::new(tsp.shard_count as u8),
             };

             let shard_identity = if tsp.shard_count == 0 {
                 ShardIdentity::unsharded()
             } else {
                 ShardIdentity::new(
                     ShardNumber(tsp.shard_number as u8),
-                    ShardCount(tsp.shard_count as u8),
+                    ShardCount::new(tsp.shard_count as u8),
                     ShardStripeSize(tsp.shard_stripe_size as u32),
                 )?
             };

@@ -526,7 +526,7 @@ impl Service {
         let tsp = TenantShardPersistence {
             tenant_id: attach_req.tenant_shard_id.tenant_id.to_string(),
             shard_number: attach_req.tenant_shard_id.shard_number.0 as i32,
-            shard_count: attach_req.tenant_shard_id.shard_count.0 as i32,
+            shard_count: attach_req.tenant_shard_id.shard_count.literal() as i32,
             shard_stripe_size: 0,
             generation: 0,
             generation_pageserver: i64::MAX,

@@ -726,16 +726,9 @@ impl Service {
         &self,
         create_req: TenantCreateRequest,
     ) -> Result<TenantCreateResponse, ApiError> {
-        // Shard count 0 is valid: it means create a single shard (ShardCount(0) means "unsharded")
-        let literal_shard_count = if create_req.shard_parameters.is_unsharded() {
-            1
-        } else {
-            create_req.shard_parameters.count.0
-        };
-
         // This service expects to handle sharding itself: it is an error to try and directly create
         // a particular shard here.
-        let tenant_id = if create_req.new_tenant_id.shard_count > ShardCount(1) {
+        let tenant_id = if !create_req.new_tenant_id.is_unsharded() {
             return Err(ApiError::BadRequest(anyhow::anyhow!(
                 "Attempted to create a specific shard, this API is for creating the whole tenant"
             )));

@@ -749,7 +742,7 @@ impl Service {
             create_req.shard_parameters.count,
         );

-        let create_ids = (0..literal_shard_count)
+        let create_ids = (0..create_req.shard_parameters.count.count())
             .map(|i| TenantShardId {
                 tenant_id,
                 shard_number: ShardNumber(i),

@@ -769,7 +762,7 @@ impl Service {
             .map(|tenant_shard_id| TenantShardPersistence {
                 tenant_id: tenant_shard_id.tenant_id.to_string(),
                 shard_number: tenant_shard_id.shard_number.0 as i32,
-                shard_count: tenant_shard_id.shard_count.0 as i32,
+                shard_count: tenant_shard_id.shard_count.literal() as i32,
                 shard_stripe_size: create_req.shard_parameters.stripe_size.0 as i32,
                 generation: create_req.generation.map(|g| g as i32).unwrap_or(0),
                 generation_pageserver: i64::MAX,

@@ -914,7 +907,7 @@ impl Service {
         tenant_id: TenantId,
         req: TenantLocationConfigRequest,
     ) -> Result<TenantLocationConfigResponse, ApiError> {
-        if req.tenant_id.shard_count.0 > 1 {
+        if !req.tenant_id.is_unsharded() {
             return Err(ApiError::BadRequest(anyhow::anyhow!(
                 "This API is for importing single-sharded or unsharded tenants"
             )));

@@ -1449,7 +1442,7 @@ impl Service {
            for (tenant_shard_id, shard) in
                locked.tenants.range(TenantShardId::tenant_range(tenant_id))
            {
-               match shard.shard.count.0.cmp(&split_req.new_shard_count) {
+               match shard.shard.count.count().cmp(&split_req.new_shard_count) {
                    Ordering::Equal => {
                        // Already split this
                        children_found.push(*tenant_shard_id);

@@ -1459,7 +1452,7 @@ impl Service {
                        return Err(ApiError::BadRequest(anyhow::anyhow!(
                            "Requested count {} but already have shards at count {}",
                            split_req.new_shard_count,
-                           shard.shard.count.0
+                           shard.shard.count.count()
                        )));
                    }
                    Ordering::Less => {

@@ -1489,7 +1482,7 @@ impl Service {
                    shard_ident = Some(shard.shard);
                }

-               if tenant_shard_id.shard_count == ShardCount(split_req.new_shard_count) {
+               if tenant_shard_id.shard_count.count() == split_req.new_shard_count {
                    tracing::info!(
                        "Tenant shard {} already has shard count {}",
                        tenant_shard_id,

@@ -1515,7 +1508,7 @@ impl Service {
                targets.push(SplitTarget {
                    parent_id: *tenant_shard_id,
                    node: node.clone(),
-                   child_ids: tenant_shard_id.split(ShardCount(split_req.new_shard_count)),
+                   child_ids: tenant_shard_id.split(ShardCount::new(split_req.new_shard_count)),
                });
            }

@@ -1562,7 +1555,7 @@ impl Service {
                    this_child_tsps.push(TenantShardPersistence {
                        tenant_id: child.tenant_id.to_string(),
                        shard_number: child.shard_number.0 as i32,
-                       shard_count: child.shard_count.0 as i32,
+                       shard_count: child.shard_count.literal() as i32,
                        shard_stripe_size: shard_ident.stripe_size.0 as i32,
                        // Note: this generation is a placeholder, [`Persistence::begin_shard_split`] will
                        // populate the correct generation as part of its transaction, to protect us

View File

@@ -450,7 +450,7 @@ async fn handle_tenant(
                 new_tenant_id: TenantShardId::unsharded(tenant_id),
                 generation: None,
                 shard_parameters: ShardParameters {
-                    count: ShardCount(shard_count),
+                    count: ShardCount::new(shard_count),
                     stripe_size: shard_stripe_size
                         .map(ShardStripeSize)
                         .unwrap_or(ShardParameters::DEFAULT_STRIPE_SIZE),

View File

@@ -214,14 +214,14 @@ impl ShardParameters {
     pub const DEFAULT_STRIPE_SIZE: ShardStripeSize = ShardStripeSize(256 * 1024 / 8);

     pub fn is_unsharded(&self) -> bool {
-        self.count == ShardCount(0)
+        self.count.is_unsharded()
     }
 }

 impl Default for ShardParameters {
     fn default() -> Self {
         Self {
-            count: ShardCount(0),
+            count: ShardCount::new(0),
             stripe_size: Self::DEFAULT_STRIPE_SIZE,
         }
     }

View File

@@ -13,10 +13,41 @@ use utils::id::TenantId;
 pub struct ShardNumber(pub u8);

 #[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug, Hash)]
-pub struct ShardCount(pub u8);
+pub struct ShardCount(u8);

 impl ShardCount {
     pub const MAX: Self = Self(u8::MAX);
+
+    /// The internal value of a ShardCount may be zero, which means "1 shard, but use
+    /// legacy format for TenantShardId that excludes the shard suffix", also known
+    /// as `TenantShardId::unsharded`.
+    ///
+    /// This method returns the actual number of shards, i.e. if our internal value is
+    /// zero, we return 1 (unsharded tenants have 1 shard).
+    pub fn count(&self) -> u8 {
+        if self.0 > 0 {
+            self.0
+        } else {
+            1
+        }
+    }
+
+    /// The literal internal value: this is **not** the number of shards in the
+    /// tenant, as we have a special zero value for legacy unsharded tenants. Use
+    /// [`Self::count`] if you want to know the cardinality of shards.
+    pub fn literal(&self) -> u8 {
+        self.0
+    }
+
+    pub fn is_unsharded(&self) -> bool {
+        self.0 == 0
+    }
+
+    /// `v` may be zero, or the number of shards in the tenant. `v` is what
+    /// [`Self::literal`] would return.
+    pub fn new(val: u8) -> Self {
+        Self(val)
+    }
 }

 impl ShardNumber {

@@ -86,7 +117,7 @@ impl TenantShardId {
     }

     pub fn is_unsharded(&self) -> bool {
-        self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
+        self.shard_number == ShardNumber(0) && self.shard_count.is_unsharded()
     }

     /// Convenience for dropping the tenant_id and just getting the ShardIndex: this
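
The invariants documented above are easy to pin down; a minimal test sketch (not part of this diff):

```rust
#[test]
fn shard_count_zero_means_one_shard() {
    use pageserver_api::shard::ShardCount;

    // `new` stores exactly what `literal` reads back, including zero.
    for v in [0u8, 1, 2, 4, u8::MAX] {
        assert_eq!(ShardCount::new(v).literal(), v);
    }

    // Only the legacy zero value has a cardinality different from its literal.
    assert_eq!(ShardCount::new(0).count(), 1);
    assert_eq!(ShardCount::new(4).count(), 4);
    assert!(ShardCount::new(0).is_unsharded());
    assert!(!ShardCount::new(1).is_unsharded());
}
```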

View File

@@ -1136,7 +1136,7 @@ async fn tenant_shard_split_handler(
     let new_shards = state
         .tenant_manager
-        .shard_split(tenant_shard_id, ShardCount(req.new_shard_count), &ctx)
+        .shard_split(tenant_shard_id, ShardCount::new(req.new_shard_count), &ctx)
         .await
         .map_err(ApiError::InternalServerError)?;

View File

@@ -26,7 +26,7 @@ use pageserver_api::models::{
     PagestreamNblocksResponse,
 };
 use pageserver_api::shard::ShardIndex;
-use pageserver_api::shard::{ShardCount, ShardNumber};
+use pageserver_api::shard::ShardNumber;
 use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
 use pq_proto::framed::ConnectionError;
 use pq_proto::FeStartupPacket;

@@ -998,7 +998,7 @@ impl PageServerHandler {
     ) -> Result<&Arc<Timeline>, Key> {
         let key = if let Some((first_idx, first_timeline)) = self.shard_timelines.iter().next() {
             // Fastest path: single sharded case
-            if first_idx.shard_count < ShardCount(2) {
+            if first_idx.shard_count.count() == 1 {
                 return Ok(&first_timeline.timeline);
             }

View File

@@ -2370,7 +2370,7 @@ impl Tenant {
             generation: self.generation.into(),
             secondary_conf: None,
             shard_number: self.shard_identity.number.0,
-            shard_count: self.shard_identity.count.0,
+            shard_count: self.shard_identity.count.literal(),
             shard_stripe_size: self.shard_identity.stripe_size.0,
             tenant_conf: tenant_config,
         }

View File

@@ -251,7 +251,7 @@ impl LocationConf {
             } else {
                 ShardIdentity::new(
                     ShardNumber(conf.shard_number),
-                    ShardCount(conf.shard_count),
+                    ShardCount::new(conf.shard_count),
                     ShardStripeSize(conf.shard_stripe_size),
                 )?
             };

View File

@@ -794,7 +794,7 @@ pub(crate) async fn set_new_tenant_config(
     info!("configuring tenant {tenant_id}");
     let tenant = get_tenant(tenant_shard_id, true)?;

-    if tenant.tenant_shard_id().shard_count > ShardCount(0) {
+    if !tenant.tenant_shard_id().shard_count.is_unsharded() {
         // Note that we use ShardParameters::default below.
         return Err(SetNewTenantConfigError::Other(anyhow::anyhow!(
             "This API may only be used on single-sharded tenants, use the /location_config API for sharded tenants"

@@ -1376,7 +1376,7 @@ impl TenantManager {
         result
     }

-    #[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), new_shard_count=%new_shard_count.0))]
+    #[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), new_shard_count=%new_shard_count.literal()))]
     pub(crate) async fn shard_split(
         &self,
         tenant_shard_id: TenantShardId,

@@ -1386,11 +1386,10 @@ impl TenantManager {
         let tenant = get_tenant(tenant_shard_id, true)?;

         // Plan: identify what the new child shards will be
-        let effective_old_shard_count = std::cmp::max(tenant_shard_id.shard_count.0, 1);
-        if new_shard_count <= ShardCount(effective_old_shard_count) {
+        if new_shard_count.count() <= tenant_shard_id.shard_count.count() {
             anyhow::bail!("Requested shard count is not an increase");
         }
-        let expansion_factor = new_shard_count.0 / effective_old_shard_count;
+        let expansion_factor = new_shard_count.count() / tenant_shard_id.shard_count.count();
         if !expansion_factor.is_power_of_two() {
             anyhow::bail!("Requested split is not a power of two");
         }
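
To see why `count()` removes the old `max(.0, 1)` dance here, consider a hypothetical split of a legacy unsharded tenant (literal 0) into four shards:

```rust
use pageserver_api::shard::ShardCount;

let old = ShardCount::new(0); // legacy unsharded: count() == 1
let new = ShardCount::new(4);
assert!(new.count() > old.count()); // the split is an increase
let expansion_factor = new.count() / old.count(); // 4 / 1 = 4, no division by zero
assert!(expansion_factor.is_power_of_two());
```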

View File

@@ -150,7 +150,7 @@ impl SecondaryTenant {
             generation: None,
             secondary_conf: Some(conf),
             shard_number: self.tenant_shard_id.shard_number.0,
-            shard_count: self.tenant_shard_id.shard_count.0,
+            shard_count: self.tenant_shard_id.shard_count.literal(),
             shard_stripe_size: self.shard_identity.stripe_size.0,
             tenant_conf: tenant_conf.into(),
         }