From c86574eb8243f27f8ffc0d88bf856d3c91a0fab1 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 28 Mar 2026 13:40:58 +0800 Subject: [PATCH] put it all into the rule Signed-off-by: Ruihang Xia --- .../optimizer/reduce_aggregate_repartition.rs | 84 +++++++++++++++---- src/query/src/query_engine/state.rs | 42 +--------- 2 files changed, 70 insertions(+), 56 deletions(-) diff --git a/src/query/src/optimizer/reduce_aggregate_repartition.rs b/src/query/src/optimizer/reduce_aggregate_repartition.rs index 50ca1cca7f..981ed5698f 100644 --- a/src/query/src/optimizer/reduce_aggregate_repartition.rs +++ b/src/query/src/optimizer/reduce_aggregate_repartition.rs @@ -23,6 +23,9 @@ use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, InputOrd use datafusion_common::Result as DfResult; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::{Distribution, Partitioning}; +use promql::extension_plan::{ + InstantManipulateExec, RangeManipulateExec, SeriesDivideExec, SeriesNormalizeExec, +}; /// Replaces a redundant hash repartition before a coarser aggregate with a /// single fan-in. 
@@ -80,12 +83,6 @@ impl ReduceAggregateRepartition {
             return Ok(Transformed::no(plan));
         }
 
-        let Partitioning::Hash(finer_partition_exprs, _) =
-            repartition_exec.input().output_partitioning()
-        else {
-            return Ok(Transformed::no(plan));
-        };
-
         let Some(required_distribution) =
             agg_exec.required_input_distribution().into_iter().next()
         else {
@@ -100,15 +97,7 @@ impl ReduceAggregateRepartition {
             return Ok(Transformed::no(plan));
         }
 
-        let coarsening_satisfaction = repartition_exec.partitioning().satisfaction(
-            &Distribution::HashPartitioned(finer_partition_exprs.clone()),
-            repartition_exec
-                .input()
-                .properties()
-                .equivalence_properties(),
-            true,
-        );
-        if !coarsening_satisfaction.is_subset() {
+        if !Self::can_reduce_repartition(repartition_exec) {
             return Ok(Transformed::no(plan));
         }
 
@@ -129,6 +118,71 @@ impl ReduceAggregateRepartition {
             })
             .data()
     }
+
+    fn can_reduce_repartition(repartition_exec: &RepartitionExec) -> bool {
+        let has_direct_promql_input =
+            Self::has_direct_promql_partial_input(repartition_exec.input());
+        if Self::contains_promql_exec_deep(repartition_exec.input()) {
+            return has_direct_promql_input;
+        }
+
+        let Partitioning::Hash(finer_partition_exprs, _) =
+            repartition_exec.input().output_partitioning()
+        else {
+            return false;
+        };
+
+        let coarsening_satisfaction = repartition_exec.partitioning().satisfaction(
+            &Distribution::HashPartitioned(finer_partition_exprs.clone()),
+            repartition_exec
+                .input()
+                .properties()
+                .equivalence_properties(),
+            true,
+        );
+        coarsening_satisfaction.is_subset()
+    }
+
+    fn has_direct_promql_partial_input(plan: &Arc<dyn ExecutionPlan>) -> bool {
+        let Some(partial_agg) = plan.as_any().downcast_ref::<AggregateExec>() else {
+            return false;
+        };
+
+        partial_agg.mode() == &AggregateMode::Partial
+            && Self::contains_promql_vector_exec(partial_agg.input())
+    }
+
+    fn contains_promql_vector_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
+        if Self::is_promql_vector_exec(plan) {
+            return true;
+        }
+
+        if plan.as_any().is::<RepartitionExec>() {
+            return false;
+        }
+
+        plan.children()
+            .into_iter()
+            .any(Self::contains_promql_vector_exec)
+    }
+
+    fn contains_promql_exec_deep(plan: &Arc<dyn ExecutionPlan>) -> bool {
+        if Self::is_promql_vector_exec(plan) {
+            return true;
+        }
+
+        plan.children()
+            .into_iter()
+            .any(Self::contains_promql_exec_deep)
+    }
+
+    fn is_promql_vector_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
+        let plan = plan.as_any();
+        plan.is::<InstantManipulateExec>()
+            || plan.is::<RangeManipulateExec>()
+            || plan.is::<SeriesDivideExec>()
+            || plan.is::<SeriesNormalizeExec>()
+    }
 }
 
 #[cfg(test)]
diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs
index c6a96a6b03..3ded3f77fe 100644
--- a/src/query/src/query_engine/state.rs
+++ b/src/query/src/query_engine/state.rs
@@ -49,9 +49,7 @@ use datafusion_optimizer::Analyzer;
 use datafusion_optimizer::analyzer::function_rewrite::ApplyFunctionRewrites;
 use datafusion_optimizer::optimizer::Optimizer;
 use partition::manager::PartitionRuleManagerRef;
-use promql::extension_plan::{
-    InstantManipulate, PromExtensionPlanner, RangeManipulate, SeriesDivide, SeriesNormalize,
-};
+use promql::extension_plan::PromExtensionPlanner;
 use table::TableRef;
 use table::table::adapter::DfTableProviderAdapter;
@@ -461,50 +459,12 @@ impl QueryPlanner for DfQueryPlanner {
         logical_plan: &DfLogicalPlan,
         session_state: &SessionState,
     ) -> DfResult<Arc<dyn ExecutionPlan>> {
-        let scoped_session_state;
-        let session_state = if should_disable_repartitioned_aggregations(logical_plan) {
-            scoped_session_state = {
-                let mut session_state = session_state.clone();
-                *session_state.config_mut() = session_state
-                    .config()
-                    .clone()
-                    .with_repartition_aggregations(false);
-                session_state
-            };
-            &scoped_session_state
-        } else {
-            session_state
-        };
-
         self.physical_planner
             .create_physical_plan(logical_plan, session_state)
             .await
     }
 }
 
-fn should_disable_repartitioned_aggregations(plan: &DfLogicalPlan) -> bool {
-    match plan {
-        DfLogicalPlan::Aggregate(aggregate) => contains_promql_vector_node(&aggregate.input),
-        _ => plan
-            .inputs()
-            .into_iter()
-            .any(should_disable_repartitioned_aggregations),
-    }
-}
-
-fn contains_promql_vector_node(plan: &DfLogicalPlan) -> bool {
-    match plan {
-        DfLogicalPlan::Extension(extension) => {
-            let node = extension.node.as_any();
-            node.is::<InstantManipulate>()
-                || node.is::<RangeManipulate>()
-                || node.is::<SeriesDivide>()
-                || node.is::<SeriesNormalize>()
-        }
-        _ => plan.inputs().into_iter().any(contains_promql_vector_node),
-    }
-}
-
 /// MySQL-compatible scalar function aliases: (target_name, alias)
 const SCALAR_FUNCTION_ALIASES: &[(&str, &str)] = &[
     ("upper", "ucase"),