feat: generalize aggregate repartition reduction

This commit is contained in:
Ruihang Xia
2026-03-28 10:08:24 +08:00
parent dd1781f412
commit b5d83bf087
3 changed files with 704 additions and 26 deletions

View File

@@ -22,14 +22,14 @@ use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, InputOrderMode};
use datafusion_common::Result as DfResult;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_physical_expr::{Partitioning, physical_exprs_equal};
use datafusion_physical_expr::{Distribution, Partitioning};
/// Replaces a redundant hash repartition before a coarser aggregate with a
/// single fan-in.
///
/// This only applies when the aggregate already receives explicit hash
/// partitioning on its final grouping keys and the repartition input is already
/// hash partitioned on a strict superset of those keys.
/// This only applies when the aggregate already receives hash partitioning
/// satisfying its grouping keys and the repartition input is already hash
/// partitioned on a strict superset of those repartition keys.
#[derive(Debug)]
pub struct ReduceAggregateRepartition;
@@ -80,21 +80,35 @@ impl ReduceAggregateRepartition {
return Ok(Transformed::no(plan));
}
let Partitioning::Hash(final_partition_exprs, _) = repartition_exec.partitioning()
else {
return Ok(Transformed::no(plan));
};
let Partitioning::Hash(finer_partition_exprs, _) =
repartition_exec.input().output_partitioning()
else {
return Ok(Transformed::no(plan));
};
let group_exprs = agg_exec.group_expr().input_exprs();
if !physical_exprs_equal(group_exprs.as_slice(), final_partition_exprs.as_slice())
|| !is_strict_subset(final_partition_exprs, finer_partition_exprs)
{
let Some(required_distribution) =
agg_exec.required_input_distribution().into_iter().next()
else {
return Ok(Transformed::no(plan));
};
let repartition_satisfaction = repartition_exec.partitioning().satisfaction(
&required_distribution,
repartition_exec.properties().equivalence_properties(),
true,
);
if !repartition_satisfaction.is_satisfied() {
return Ok(Transformed::no(plan));
}
let coarsening_satisfaction = repartition_exec.partitioning().satisfaction(
&Distribution::HashPartitioned(finer_partition_exprs.clone()),
repartition_exec
.input()
.properties()
.equivalence_properties(),
true,
);
if !coarsening_satisfaction.is_subset() {
return Ok(Transformed::no(plan));
}
@@ -117,19 +131,6 @@ impl ReduceAggregateRepartition {
}
}
/// Reports whether `subset_exprs` is a strict subset of `superset_exprs`.
///
/// The answer is `true` only when `subset_exprs` is non-empty, has fewer
/// entries than `superset_exprs`, and every one of its expressions compares
/// equal to at least one expression in `superset_exprs`.
fn is_strict_subset(
    subset_exprs: &[Arc<dyn datafusion_physical_expr::PhysicalExpr>],
    superset_exprs: &[Arc<dyn datafusion_physical_expr::PhysicalExpr>],
) -> bool {
    // Size guard first: an empty candidate or one that is not shorter than
    // the reference list can never be a strict subset.
    let strictly_smaller = !subset_exprs.is_empty() && subset_exprs.len() < superset_exprs.len();

    strictly_smaller
        && subset_exprs.iter().all(|subset_expr| {
            superset_exprs
                .iter()
                .any(|candidate| subset_expr.eq(candidate))
        })
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
@@ -139,6 +140,7 @@ mod tests {
use datafusion::datasource::source::DataSourceExec;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, displayable};
use datafusion_common::Result;
@@ -185,6 +187,13 @@ mod tests {
)?))
}
/// Wraps `input` in a `RepartitionExec` that uses round-robin batching over
/// eight output partitions.
fn round_robin_repartition(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    let round_robin = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8))?;
    Ok(Arc::new(round_robin))
}
fn aggregate(
mode: AggregateMode,
input: Arc<dyn ExecutionPlan>,
@@ -208,6 +217,17 @@ mod tests {
Ok(displayable(optimized.as_ref()).indent(true).to_string())
}
/// Builds a `ProjectionExec` over `input` that re-exposes columns under new
/// names.
///
/// Each `(from, to)` pair in `aliases` resolves column `from` against the
/// input schema and outputs it as `to`. The first column that fails to
/// resolve aborts the construction with that error.
fn project_with_aliases(
    input: Arc<dyn ExecutionPlan>,
    aliases: &[(&str, &str)],
) -> Result<Arc<dyn ExecutionPlan>> {
    let input_schema = input.schema();
    let mut projected: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::with_capacity(aliases.len());
    for (from, to) in aliases {
        projected.push((col(from, &input_schema)?, (*to).to_string()));
    }
    let projection = ProjectionExec::try_new(projected, input)?;
    Ok(Arc::new(projection))
}
#[test]
fn rewrites_final_partitioned_subset_repartition() -> Result<()> {
let raw = input_exec(schema());
@@ -292,4 +312,140 @@ mod tests {
);
Ok(())
}
/// Checks that the rewrite fires even when the finer hash keys (`c, a, b`) are
/// listed in a different order than the final repartition keys (`b, c`).
#[test]
fn rewrites_when_finer_key_order_differs() -> Result<()> {
let raw = input_exec(schema());
// Finer partitioning and partial aggregation both use the key order c, a, b.
let finer = repartition(raw.clone(), &["c", "a", "b"], &raw.schema())?;
let partial = aggregate(
AggregateMode::Partial,
finer,
group_by(&["c", "a", "b"], &raw.schema())?,
raw.schema(),
vec![],
)?;
// The final stage repartitions and groups on b, c only.
let final_repartition = repartition(partial.clone(), &["b", "c"], &partial.schema())?;
let final_agg = aggregate(
AggregateMode::FinalPartitioned,
final_repartition,
group_by(&["b", "c"], &partial.schema())?,
raw.schema(),
vec![],
)?;
// Expected plan: the aggregate runs in Final mode over a
// CoalescePartitionsExec, and the (b, c) hash repartition is gone.
assert_eq!(
optimize(final_agg)?.trim(),
r#"AggregateExec: mode=Final, gby=[b@2 as b, c@0 as c], aggr=[]
CoalescePartitionsExec
AggregateExec: mode=Partial, gby=[c@2 as c, a@0 as a, b@1 as b], aggr=[]
RepartitionExec: partitioning=Hash([c@2, a@0, b@1], 8), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[0]"#
);
Ok(())
}
/// Checks that the rewrite fires when the final repartition hashes on only a
/// subset (`a`) of the aggregate's grouping keys (`a, b`), while its input is
/// hash partitioned on `a, b, c`.
#[test]
fn rewrites_when_repartition_satisfies_group_by_with_subset_keys() -> Result<()> {
let raw = input_exec(schema());
let finer = repartition(raw.clone(), &["a", "b", "c"], &raw.schema())?;
// Repartition on `a` alone, directly on top of the (a, b, c) repartition.
let final_repartition = repartition(finer.clone(), &["a"], &finer.schema())?;
let final_agg = aggregate(
AggregateMode::FinalPartitioned,
final_repartition,
group_by(&["a", "b"], &finer.schema())?,
raw.schema(),
vec![],
)?;
// Expected plan: a Final aggregate over CoalescePartitionsExec, which sits
// directly on the original (a, b, c) hash repartition.
assert_eq!(
optimize(final_agg)?.trim(),
r#"AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[]
CoalescePartitionsExec
RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 8), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[0]"#
);
Ok(())
}
/// Checks that the plan is left untouched when the repartition feeding the
/// final aggregate is round-robin rather than hash based.
#[test]
fn keeps_non_hash_repartition_child() -> Result<()> {
let raw = input_exec(schema());
let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?;
let partial = aggregate(
AggregateMode::Partial,
finer,
group_by(&["a", "b"], &raw.schema())?,
raw.schema(),
vec![],
)?;
// The aggregate's direct child is a round-robin repartition, not a hash one.
let final_repartition = round_robin_repartition(partial.clone())?;
let final_agg = aggregate(
AggregateMode::FinalPartitioned,
final_repartition,
group_by(&["a"], &partial.schema())?,
raw.schema(),
vec![],
)?;
// Expected plan: FinalPartitioned mode and the round-robin repartition are
// both still present, with no CoalescePartitionsExec inserted.
assert_eq!(
optimize(final_agg)?.trim(),
r#"AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=8
AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b], aggr=[]
RepartitionExec: partitioning=Hash([a@0, b@1], 8), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[0]"#
);
Ok(())
}
/// Checks that the rewrite still applies when a projection renaming
/// `a, b, c` to `x, y, z` sits between the finer hash repartition and the
/// final repartition on `x, y`.
#[test]
fn rewrites_subset_partitioning_through_projection() -> Result<()> {
    let source = input_exec(schema());
    let hashed_abc = repartition(source.clone(), &["a", "b", "c"], &source.schema())?;
    let renamed = project_with_aliases(hashed_abc, &[("a", "x"), ("b", "y"), ("c", "z")])?;
    let hashed_xy = repartition(renamed.clone(), &["x", "y"], &renamed.schema())?;

    // Build the grouping first so the aggregate call stays flat; the
    // repartition above has already been constructed at this point.
    let grouping = group_by(&["x", "y"], &renamed.schema())?;
    let single_agg = aggregate(
        AggregateMode::SinglePartitioned,
        hashed_xy,
        grouping,
        renamed.schema(),
        vec![],
    )?;

    let optimized = optimize(single_agg)?;

    // The aggregate is downgraded to Single mode over a coalesce, and the
    // aliasing projection is still part of the plan.
    assert!(
        optimized.contains("AggregateExec: mode=Single, gby=[x@0 as x, y@1 as y], aggr=[]")
    );
    assert!(optimized.contains("CoalescePartitionsExec"));
    assert!(optimized.contains("ProjectionExec: expr=[a@0 as x, b@1 as y, c@2 as z]"));
    Ok(())
}
/// Checks that the repartition is kept when its keys (`a, b`) are not a
/// subset of the keys its input is already hash partitioned on (`a`).
#[test]
fn keeps_non_subset_repartition() -> Result<()> {
    let source = input_exec(schema());
    let hashed_a = repartition(source.clone(), &["a"], &source.schema())?;
    let hashed_ab = repartition(hashed_a.clone(), &["a", "b"], &hashed_a.schema())?;

    let grouping = group_by(&["a", "b"], &hashed_a.schema())?;
    let partitioned_agg = aggregate(
        AggregateMode::FinalPartitioned,
        hashed_ab,
        grouping,
        source.schema(),
        vec![],
    )?;

    let optimized = optimize(partitioned_agg)?;

    // The aggregate mode and the (a, b) hash repartition must both survive.
    assert!(
        optimized.contains(
            "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b], aggr=[]"
        ),
        "{optimized}"
    );
    assert!(
        optimized.contains("RepartitionExec: partitioning=Hash([a@0, b@1], 8)"),
        "{optimized}"
    );
    // No single fan-in may have been introduced.
    assert!(!optimized.contains("CoalescePartitionsExec"), "{optimized}");
    Ok(())
}
}