feat: combine coalesced min max aggregates

This commit is contained in:
Ruihang Xia
2026-03-28 05:58:15 +08:00
parent 44144df6d6
commit bde03d47d1
2 changed files with 202 additions and 8 deletions

View File

@@ -22,7 +22,8 @@ use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, InputOrderMode};
use datafusion_common::Result as DfResult;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_physical_expr::{Partitioning, physical_exprs_equal};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::{Partitioning, PhysicalExpr, physical_exprs_equal};
/// Replaces a redundant hash repartition before a coarser aggregate with a
/// single fan-in.
@@ -58,6 +59,10 @@ impl ReduceAggregateRepartition {
return Ok(Transformed::no(plan));
};
if let Some(new_plan) = Self::combine_coalesced_min_max_partial_final(agg_exec)? {
return Ok(Transformed::yes(new_plan));
}
let new_mode = match agg_exec.mode() {
AggregateMode::FinalPartitioned => AggregateMode::Final,
AggregateMode::SinglePartitioned => AggregateMode::Single,
@@ -115,6 +120,107 @@ impl ReduceAggregateRepartition {
})
.data()
}
/// Collapses `Final -> CoalescePartitionsExec -> Partial` into a single
/// `AggregateExec` in `Single` mode when every aggregate is `min`/`max`.
///
/// Returns `Ok(None)` (plan left untouched) unless all of the following hold:
/// - `agg_exec` runs in `AggregateMode::Final`;
/// - its input is a `CoalescePartitionsExec` whose own input is an
///   `AggregateExec` in `AggregateMode::Partial`;
/// - both stages only use aggregates accepted by `supports_min_max_family`;
/// - `can_combine` confirms both stages describe the same aggregation.
///
/// On success the new plan is `Single aggregate -> CoalescePartitionsExec ->
/// <input of the partial aggregate>`. Group-by, aggregate and filter
/// expressions as well as the input schema come from the partial stage; the
/// limit options come from the final stage.
///
/// # Errors
/// Propagates any error returned by `AggregateExec::try_new`.
fn combine_coalesced_min_max_partial_final(
    agg_exec: &AggregateExec,
) -> DfResult<Option<Arc<dyn ExecutionPlan>>> {
    if !matches!(agg_exec.mode(), AggregateMode::Final) {
        return Ok(None);
    }

    // Walk down `Final -> Coalesce -> Aggregate`; bail out on any other shape.
    let Some(lower_agg) = agg_exec
        .input()
        .as_any()
        .downcast_ref::<CoalescePartitionsExec>()
        .and_then(|coalesce| coalesce.input().as_any().downcast_ref::<AggregateExec>())
    else {
        return Ok(None);
    };

    if !matches!(lower_agg.mode(), AggregateMode::Partial) {
        return Ok(None);
    }
    if !supports_min_max_family(agg_exec.aggr_expr()) {
        return Ok(None);
    }
    if !supports_min_max_family(lower_agg.aggr_expr()) {
        return Ok(None);
    }

    let upper_parts = (
        agg_exec.group_expr(),
        agg_exec.aggr_expr(),
        agg_exec.filter_expr(),
    );
    let lower_parts = (
        lower_agg.group_expr(),
        lower_agg.aggr_expr(),
        lower_agg.filter_expr(),
    );
    if !can_combine(upper_parts, lower_parts) {
        return Ok(None);
    }

    // The coalesce now sits directly on top of the partial aggregate's input,
    // so the single aggregate sees every row in one partition.
    let fan_in = Arc::new(CoalescePartitionsExec::new(Arc::clone(lower_agg.input())));
    let single_agg = AggregateExec::try_new(
        AggregateMode::Single,
        lower_agg.group_expr().clone(),
        lower_agg.aggr_expr().to_vec(),
        lower_agg.filter_expr().to_vec(),
        fan_in,
        lower_agg.input_schema(),
    )?
    .with_limit_options(agg_exec.limit_options());

    Ok(Some(Arc::new(single_agg)))
}
}
/// Borrowed view of the parts of an aggregate that `can_combine` compares:
/// the group-by, the aggregate expressions, and the filter expressions
/// (each filter is optional).
type GroupExprsRef<'a> = (
&'a datafusion::physical_plan::aggregates::PhysicalGroupBy,
&'a [Arc<AggregateFunctionExpr>],
&'a [Option<Arc<dyn PhysicalExpr>>],
);
/// Decides whether a final aggregate and the partial aggregate feeding it
/// describe the same aggregation.
///
/// Each argument is a `(group by, aggregate expressions, filter expressions)`
/// triple. The result is `true` only when all of these hold:
/// - the partial group-by's output expressions equal the final group-by's
///   input expressions;
/// - both group-bys report identical `groups()`;
/// - their null expressions match pairwise, both expression and name;
/// - their aggregate expressions match pairwise;
/// - their filters match pairwise (both absent, or both present and equal).
fn can_combine(final_agg: GroupExprsRef<'_>, partial_agg: GroupExprsRef<'_>) -> bool {
    let (upper_group_by, upper_aggrs, upper_filters) = final_agg;
    let (lower_group_by, lower_aggrs, lower_filters) = partial_agg;

    // Grouping columns: what the partial stage emits is what the final stage reads.
    if !physical_exprs_equal(
        &lower_group_by.output_exprs(),
        &upper_group_by.input_exprs(),
    ) {
        return false;
    }
    if lower_group_by.groups() != upper_group_by.groups() {
        return false;
    }

    // Null expressions must agree in count, expression and name.
    let lower_nulls = lower_group_by.null_expr();
    let upper_nulls = upper_group_by.null_expr();
    if lower_nulls.len() != upper_nulls.len() {
        return false;
    }
    for ((lower_expr, lower_name), (upper_expr, upper_name)) in
        lower_nulls.iter().zip(upper_nulls.iter())
    {
        if !(lower_expr.eq(upper_expr) && lower_name == upper_name) {
            return false;
        }
    }

    // Aggregate expressions must be the same, position by position.
    if upper_aggrs.len() != lower_aggrs.len() {
        return false;
    }
    if upper_aggrs
        .iter()
        .zip(lower_aggrs.iter())
        .any(|(upper, lower)| !upper.eq(lower))
    {
        return false;
    }

    // Filters: equal count, and each pair is either absent on both sides or equal.
    upper_filters.len() == lower_filters.len()
        && upper_filters
            .iter()
            .zip(lower_filters.iter())
            .all(|pair| match pair {
                (Some(upper), Some(lower)) => upper.eq(lower),
                (None, None) => true,
                (Some(_), None) | (None, Some(_)) => false,
            })
}
/// Returns `true` when `aggr_exprs` is non-empty and every aggregate's
/// function name is `min` or `max` (ASCII case-insensitive).
///
/// NOTE(review): the check is by function name only, so any aggregate
/// function registered under one of these names is accepted — confirm that is
/// intended.
fn supports_min_max_family(aggr_exprs: &[Arc<AggregateFunctionExpr>]) -> bool {
    // An aggregate without any aggregate expression is never eligible.
    if aggr_exprs.is_empty() {
        return false;
    }

    let is_min_or_max =
        |name: &str| name.eq_ignore_ascii_case("min") || name.eq_ignore_ascii_case("max");

    for expr in aggr_exprs {
        if !is_min_or_max(expr.fun().name()) {
            return false;
        }
    }
    true
}
fn is_strict_subset(
@@ -137,13 +243,17 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::datasource::memory::MemorySourceConfig;
use datafusion::datasource::source::DataSourceExec;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::min_max::min_udaf;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, displayable};
use datafusion_common::Result;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
use datafusion_physical_expr::{Partitioning, PhysicalExpr, aggregate::AggregateFunctionExpr};
use pretty_assertions::assert_eq;
use super::ReduceAggregateRepartition;
@@ -189,17 +299,37 @@ mod tests {
input: Arc<dyn ExecutionPlan>,
group_by: PhysicalGroupBy,
input_schema: SchemaRef,
aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let filter_expr = vec![None; aggr_expr.len()];
Ok(Arc::new(AggregateExec::try_new(
mode,
group_by,
vec![],
vec![],
aggr_expr,
filter_expr,
input,
input_schema,
)?))
}
/// Test helper: builds a `min` aggregate over column `name` of `schema`,
/// aliased as `min(<name>)`.
fn min_expr(name: &str, schema: &SchemaRef) -> Result<Arc<AggregateFunctionExpr>> {
    let udaf = min_udaf();
    let argument = col(name, schema)?;
    let alias = format!("min({name})");

    let built = AggregateExprBuilder::new(udaf, vec![argument])
        .schema(schema.clone())
        .alias(alias)
        .build()?;
    Ok(Arc::new(built))
}
/// Test helper: builds a `count` aggregate over column `name` of `schema`,
/// aliased as `count(<name>)`.
fn count_expr(name: &str, schema: &SchemaRef) -> Result<Arc<AggregateFunctionExpr>> {
    let udaf = count_udaf();
    let argument = col(name, schema)?;
    let alias = format!("count({name})");

    let built = AggregateExprBuilder::new(udaf, vec![argument])
        .schema(schema.clone())
        .alias(alias)
        .build()?;
    Ok(Arc::new(built))
}
fn optimize(plan: Arc<dyn ExecutionPlan>) -> Result<String> {
let optimized = ReduceAggregateRepartition.optimize(plan, &Default::default())?;
Ok(displayable(optimized.as_ref()).indent(true).to_string())
@@ -214,6 +344,7 @@ mod tests {
finer,
group_by(&["a", "b"], &raw.schema())?,
raw.schema(),
vec![],
)?;
let final_repartition = repartition(partial.clone(), &["a"], &partial.schema())?;
let final_agg = aggregate(
@@ -221,6 +352,7 @@ mod tests {
final_repartition,
group_by(&["a"], &partial.schema())?,
raw.schema(),
vec![],
)?;
assert_eq!(
@@ -244,6 +376,7 @@ mod tests {
final_repartition,
group_by(&["a"], &finer.schema())?,
raw.schema(),
vec![],
)?;
assert_eq!(
@@ -265,6 +398,7 @@ mod tests {
finer,
group_by(&["a", "b"], &raw.schema())?,
raw.schema(),
vec![],
)?;
let final_repartition = repartition(partial.clone(), &["a", "b"], &partial.schema())?;
let final_agg = aggregate(
@@ -272,6 +406,7 @@ mod tests {
final_repartition,
group_by(&["a", "b"], &partial.schema())?,
raw.schema(),
vec![],
)?;
assert_eq!(
@@ -284,4 +419,65 @@ mod tests {
);
Ok(())
}
// A `Final` aggregate over `CoalescePartitionsExec` over a matching `Partial`
// aggregate must collapse into one `Single` aggregate when the only
// aggregate function is `min`.
#[test]
fn combines_coalesced_partial_final_for_min() -> Result<()> {
let raw = input_exec(schema());
// Hash-repartition on (a, b) below the partial aggregate.
let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?;
// The same `min(c)` expression is used by both aggregate stages.
let aggr_expr = vec![min_expr("c", &raw.schema())?];
let partial = aggregate(
AggregateMode::Partial,
finer,
group_by(&["a"], &raw.schema())?,
raw.schema(),
aggr_expr.clone(),
)?;
// Final stage reads the partial stage through a single fan-in.
let final_agg = aggregate(
AggregateMode::Final,
Arc::new(CoalescePartitionsExec::new(partial)),
group_by(&["a"], &raw.schema())?,
raw.schema(),
aggr_expr,
)?;
// Expected: the partial aggregate is gone and the coalesce sits directly on
// top of the repartition.
assert_eq!(
optimize(final_agg)?.trim(),
r#"AggregateExec: mode=Single, gby=[a@0 as a], aggr=[min(c)]
CoalescePartitionsExec
RepartitionExec: partitioning=Hash([a@0, b@1], 8), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[0]"#
);
Ok(())
}
// The same `Final -> CoalescePartitionsExec -> Partial` shape must be left
// untouched when the aggregate function is `count` rather than `min`/`max`.
#[test]
fn keeps_coalesced_partial_final_for_count() -> Result<()> {
let raw = input_exec(schema());
// Hash-repartition on (a, b) below the partial aggregate.
let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?;
// The same `count(c)` expression is used by both aggregate stages.
let aggr_expr = vec![count_expr("c", &raw.schema())?];
let partial = aggregate(
AggregateMode::Partial,
finer,
group_by(&["a"], &raw.schema())?,
raw.schema(),
aggr_expr.clone(),
)?;
// Final stage reads the partial stage through a single fan-in.
let final_agg = aggregate(
AggregateMode::Final,
Arc::new(CoalescePartitionsExec::new(partial)),
group_by(&["a"], &raw.schema())?,
raw.schema(),
aggr_expr,
)?;
// Expected: both aggregate stages are still present after optimization.
assert_eq!(
optimize(final_agg)?.trim(),
r#"AggregateExec: mode=Final, gby=[a@0 as a], aggr=[count(c)]
CoalescePartitionsExec
AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(c)]
RepartitionExec: partitioning=Hash([a@0, b@1], 8), input_partitions=1
DataSourceExec: partitions=1, partition_sizes=[0]"#
);
Ok(())
}
}

View File

@@ -58,9 +58,8 @@ tql analyze (1752591864, 1752592164, '30s') max by (a, b, c) (max_over_time(aggr
|_|_|_MergeScanExec: REDACTED
|_|_|_|
| 1_| 0_|_SortExec: expr=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST, greptime_timestamp@3 ASC NULLS LAST], preserve_partitioning=[false] REDACTED
|_|_|_AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b, c@2 as c, greptime_timestamp@3 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED
|_|_|_AggregateExec: mode=Single, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED
|_|_|_CoalescePartitionsExec REDACTED
|_|_|_AggregateExec: mode=Partial, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED
|_|_|_FilterExec: prom_max_over_time(greptime_timestamp_range,greptime_value)@1 IS NOT NULL REDACTED
|_|_|_ProjectionExec: expr=[greptime_timestamp@4 as greptime_timestamp, prom_max_over_time(greptime_timestamp_range@6, greptime_value@5) as prom_max_over_time(greptime_timestamp_range,greptime_value), a@0 as a, b@1 as b, c@2 as c] REDACTED
|_|_|_PromRangeManipulateExec: req range=[1752591864000..1752592164000], interval=[30000], eval range=[120000], time index=[greptime_timestamp] REDACTED
@@ -69,9 +68,8 @@ tql analyze (1752591864, 1752592164, '30s') max by (a, b, c) (max_over_time(aggr
|_|_|_SeriesScan: region=REDACTED, "partition_count":{"count":0, "mem_ranges":0, "files":0, "file_ranges":0}, "distribution":"PerSeries" REDACTED
|_|_|_|
| 1_| 1_|_SortExec: expr=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST, greptime_timestamp@3 ASC NULLS LAST], preserve_partitioning=[false] REDACTED
|_|_|_AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b, c@2 as c, greptime_timestamp@3 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED
|_|_|_AggregateExec: mode=Single, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED
|_|_|_CoalescePartitionsExec REDACTED
|_|_|_AggregateExec: mode=Partial, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED
|_|_|_FilterExec: prom_max_over_time(greptime_timestamp_range,greptime_value)@1 IS NOT NULL REDACTED
|_|_|_ProjectionExec: expr=[greptime_timestamp@4 as greptime_timestamp, prom_max_over_time(greptime_timestamp_range@6, greptime_value@5) as prom_max_over_time(greptime_timestamp_range,greptime_value), a@0 as a, b@1 as b, c@2 as c] REDACTED
|_|_|_PromRangeManipulateExec: req range=[1752591864000..1752592164000], interval=[30000], eval range=[120000], time index=[greptime_timestamp] REDACTED