diff --git a/src/query/src/optimizer/reduce_aggregate_repartition.rs b/src/query/src/optimizer/reduce_aggregate_repartition.rs index 722268b503..1bc7ae078e 100644 --- a/src/query/src/optimizer/reduce_aggregate_repartition.rs +++ b/src/query/src/optimizer/reduce_aggregate_repartition.rs @@ -22,7 +22,8 @@ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, InputOrderMode}; use datafusion_common::Result as DfResult; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::{Partitioning, physical_exprs_equal}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::{Partitioning, PhysicalExpr, physical_exprs_equal}; /// Replaces a redundant hash repartition before a coarser aggregate with a /// single fan-in. @@ -58,6 +59,10 @@ impl ReduceAggregateRepartition { return Ok(Transformed::no(plan)); }; + if let Some(new_plan) = Self::combine_coalesced_min_max_partial_final(agg_exec)? { + return Ok(Transformed::yes(new_plan)); + } + let new_mode = match agg_exec.mode() { AggregateMode::FinalPartitioned => AggregateMode::Final, AggregateMode::SinglePartitioned => AggregateMode::Single, @@ -115,6 +120,107 @@ impl ReduceAggregateRepartition { }) .data() } + + fn combine_coalesced_min_max_partial_final( + agg_exec: &AggregateExec, + ) -> DfResult>> { + if agg_exec.mode() != &AggregateMode::Final { + return Ok(None); + } + + let Some(coalesce_exec) = agg_exec + .input() + .as_any() + .downcast_ref::() + else { + return Ok(None); + }; + + let Some(partial_exec) = coalesce_exec + .input() + .as_any() + .downcast_ref::() + else { + return Ok(None); + }; + + if *partial_exec.mode() != AggregateMode::Partial + || !supports_min_max_family(agg_exec.aggr_expr()) + || !supports_min_max_family(partial_exec.aggr_expr()) + || !can_combine( + ( + agg_exec.group_expr(), + agg_exec.aggr_expr(), + agg_exec.filter_expr(), + ), + ( + partial_exec.group_expr(), + partial_exec.aggr_expr(), + partial_exec.filter_expr(), + ), + ) + { + return Ok(None); + } + + let new_input = Arc::new(CoalescePartitionsExec::new(partial_exec.input().clone())); + let new_agg = AggregateExec::try_new( + AggregateMode::Single, + partial_exec.group_expr().clone(), + partial_exec.aggr_expr().to_vec(), + partial_exec.filter_expr().to_vec(), + new_input, + partial_exec.input_schema(), + )? + .with_limit_options(agg_exec.limit_options()); + + Ok(Some(Arc::new(new_agg))) + } +} + +type GroupExprsRef<'a> = ( + &'a datafusion::physical_plan::aggregates::PhysicalGroupBy, + &'a [Arc], + &'a [Option>], +); + +fn can_combine(final_agg: GroupExprsRef<'_>, partial_agg: GroupExprsRef<'_>) -> bool { + let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg; + let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg; + + physical_exprs_equal( + &input_group_by.output_exprs(), + &final_group_by.input_exprs(), + ) && input_group_by.groups() == final_group_by.groups() + && input_group_by.null_expr().len() == final_group_by.null_expr().len() + && input_group_by + .null_expr() + .iter() + .zip(final_group_by.null_expr().iter()) + .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| { + lhs_expr.eq(rhs_expr) && lhs_str == rhs_str + }) + && final_aggr_expr.len() == input_aggr_expr.len() + && final_aggr_expr + .iter() + .zip(input_aggr_expr.iter()) + .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr)) + && final_filter_expr.len() == input_filter_expr.len() + && final_filter_expr.iter().zip(input_filter_expr.iter()).all( + |(final_expr, partial_expr)| match (final_expr, partial_expr) { + (Some(l), Some(r)) => l.eq(r), + (None, None) => true, + _ => false, + }, + ) +} + +fn supports_min_max_family(aggr_exprs: &[Arc]) -> bool { + !aggr_exprs.is_empty() + && aggr_exprs.iter().all(|expr| { + let name = expr.fun().name(); + name.eq_ignore_ascii_case("min") || name.eq_ignore_ascii_case("max") + }) } fn is_strict_subset( @@ -137,13 +243,17 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; + use datafusion::functions_aggregate::count::count_udaf; + use datafusion::functions_aggregate::min_max::min_udaf; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::{ExecutionPlan, displayable}; use datafusion_common::Result; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::{Partitioning, PhysicalExpr}; + use datafusion_physical_expr::{Partitioning, PhysicalExpr, aggregate::AggregateFunctionExpr}; use pretty_assertions::assert_eq; use super::ReduceAggregateRepartition; @@ -189,17 +299,37 @@ mod tests { input: Arc, group_by: PhysicalGroupBy, input_schema: SchemaRef, + aggr_expr: Vec>, ) -> Result> { + let filter_expr = vec![None; aggr_expr.len()]; Ok(Arc::new(AggregateExec::try_new( mode, group_by, - vec![], - vec![], + aggr_expr, + filter_expr, input, input_schema, )?)) } + fn min_expr(name: &str, schema: &SchemaRef) -> Result> { + Ok(Arc::new( + AggregateExprBuilder::new(min_udaf(), vec![col(name, schema)?]) + .schema(schema.clone()) + .alias(format!("min({name})")) + .build()?, + )) + } + + fn count_expr(name: &str, schema: &SchemaRef) -> Result> { + Ok(Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col(name, schema)?]) + .schema(schema.clone()) + .alias(format!("count({name})")) + .build()?, + )) + } + fn optimize(plan: Arc) -> Result { let optimized = ReduceAggregateRepartition.optimize(plan, &Default::default())?; Ok(displayable(optimized.as_ref()).indent(true).to_string()) @@ -214,6 +344,7 @@ mod tests { finer, group_by(&["a", "b"], &raw.schema())?, raw.schema(), + vec![], )?; let final_repartition = repartition(partial.clone(), &["a"], &partial.schema())?; let final_agg = aggregate( @@ -221,6 +352,7 @@ mod tests { final_repartition, group_by(&["a"], &partial.schema())?, raw.schema(), + vec![], )?; assert_eq!( @@ -244,6 +376,7 @@ mod tests { final_repartition, group_by(&["a"], &finer.schema())?, raw.schema(), + vec![], )?; assert_eq!( @@ -265,6 +398,7 @@ mod tests { finer, group_by(&["a", "b"], &raw.schema())?, raw.schema(), + vec![], )?; let final_repartition = repartition(partial.clone(), &["a", "b"], &partial.schema())?; let final_agg = aggregate( @@ -272,6 +406,7 @@ mod tests { final_repartition, group_by(&["a", "b"], &partial.schema())?, raw.schema(), + vec![], )?; assert_eq!( @@ -284,4 +419,65 @@ mod tests { ); Ok(()) } + + #[test] + fn combines_coalesced_partial_final_for_min() -> Result<()> { + let raw = input_exec(schema()); + let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?; + let aggr_expr = vec![min_expr("c", &raw.schema())?]; + let partial = aggregate( + AggregateMode::Partial, + finer, + group_by(&["a"], &raw.schema())?, + raw.schema(), + aggr_expr.clone(), + )?; + let final_agg = aggregate( + AggregateMode::Final, + Arc::new(CoalescePartitionsExec::new(partial)), + group_by(&["a"], &raw.schema())?, + raw.schema(), + aggr_expr, + )?; + + assert_eq!( + optimize(final_agg)?.trim(), + r#"AggregateExec: mode=Single, gby=[a@0 as a], aggr=[min(c)] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([a@0, b@1], 8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0]"# + ); + Ok(()) + } + + #[test] + fn keeps_coalesced_partial_final_for_count() -> Result<()> { + let raw = input_exec(schema()); + let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?; + let aggr_expr = vec![count_expr("c", &raw.schema())?]; + let partial = aggregate( + AggregateMode::Partial, + finer, + group_by(&["a"], &raw.schema())?, + raw.schema(), + aggr_expr.clone(), + )?; + let final_agg = aggregate( + AggregateMode::Final, + Arc::new(CoalescePartitionsExec::new(partial)), + group_by(&["a"], &raw.schema())?, + raw.schema(), + aggr_expr, + )?; + + assert_eq!( + optimize(final_agg)?.trim(), + r#"AggregateExec: mode=Final, gby=[a@0 as a], aggr=[count(c)] + CoalescePartitionsExec + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(c)] + RepartitionExec: partitioning=Hash([a@0, b@1], 8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0]"# + ); + Ok(()) + } } diff --git a/tests/cases/distributed/explain/step_aggr_advance.result b/tests/cases/distributed/explain/step_aggr_advance.result index 1e060ec310..a6888a4890 100644 --- a/tests/cases/distributed/explain/step_aggr_advance.result +++ b/tests/cases/distributed/explain/step_aggr_advance.result @@ -58,9 +58,8 @@ tql analyze (1752591864, 1752592164, '30s') max by (a, b, c) (max_over_time(aggr |_|_|_MergeScanExec: REDACTED |_|_|_| | 1_| 0_|_SortExec: expr=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST, greptime_timestamp@3 ASC NULLS LAST], preserve_partitioning=[false] REDACTED -|_|_|_AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b, c@2 as c, greptime_timestamp@3 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED +|_|_|_AggregateExec: mode=Single, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED |_|_|_CoalescePartitionsExec REDACTED -|_|_|_AggregateExec: mode=Partial, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED |_|_|_FilterExec: prom_max_over_time(greptime_timestamp_range,greptime_value)@1 IS NOT NULL REDACTED |_|_|_ProjectionExec: expr=[greptime_timestamp@4 as greptime_timestamp, prom_max_over_time(greptime_timestamp_range@6, greptime_value@5) as prom_max_over_time(greptime_timestamp_range,greptime_value), a@0 as a, b@1 as b, c@2 as c] REDACTED |_|_|_PromRangeManipulateExec: req range=[1752591864000..1752592164000], interval=[30000], eval range=[120000], time index=[greptime_timestamp] REDACTED @@ -69,9 +68,8 @@ tql analyze (1752591864, 1752592164, '30s') max by (a, b, c) (max_over_time(aggr |_|_|_SeriesScan: region=REDACTED, "partition_count":{"count":0, "mem_ranges":0, "files":0, "file_ranges":0}, "distribution":"PerSeries" REDACTED |_|_|_| | 1_| 1_|_SortExec: expr=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST, greptime_timestamp@3 ASC NULLS LAST], preserve_partitioning=[false] REDACTED -|_|_|_AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b, c@2 as c, greptime_timestamp@3 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED +|_|_|_AggregateExec: mode=Single, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED |_|_|_CoalescePartitionsExec REDACTED -|_|_|_AggregateExec: mode=Partial, gby=[a@2 as a, b@3 as b, c@4 as c, greptime_timestamp@0 as greptime_timestamp], aggr=[max(prom_max_over_time(greptime_timestamp_range,greptime_value))] REDACTED |_|_|_FilterExec: prom_max_over_time(greptime_timestamp_range,greptime_value)@1 IS NOT NULL REDACTED |_|_|_ProjectionExec: expr=[greptime_timestamp@4 as greptime_timestamp, prom_max_over_time(greptime_timestamp_range@6, greptime_value@5) as prom_max_over_time(greptime_timestamp_range,greptime_value), a@0 as a, b@1 as b, c@2 as c] REDACTED |_|_|_PromRangeManipulateExec: req range=[1752591864000..1752592164000], interval=[30000], eval range=[120000], time index=[greptime_timestamp] REDACTED