From 024ca1af795321947f8056b127ae89d991fba895 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Wed, 15 Apr 2026 08:44:33 +0800 Subject: [PATCH] refactor Signed-off-by: Ruihang Xia --- .../optimizer/reduce_aggregate_repartition.rs | 87 +++++++------------ 1 file changed, 31 insertions(+), 56 deletions(-) diff --git a/src/query/src/optimizer/reduce_aggregate_repartition.rs b/src/query/src/optimizer/reduce_aggregate_repartition.rs index 981ed5698f..8da3b6b9c3 100644 --- a/src/query/src/optimizer/reduce_aggregate_repartition.rs +++ b/src/query/src/optimizer/reduce_aggregate_repartition.rs @@ -120,9 +120,13 @@ impl ReduceAggregateRepartition { } fn can_reduce_repartition(repartition_exec: &RepartitionExec) -> bool { - let has_direct_promql_input = - Self::has_direct_promql_partial_input(repartition_exec.input()); - if Self::contains_promql_exec_deep(repartition_exec.input()) { + let has_direct_promql_input = matches!( + repartition_exec.input().as_any().downcast_ref::(), + Some(partial_agg) + if partial_agg.mode() == &AggregateMode::Partial + && Self::contains_promql_exec(partial_agg.input(), true) + ); + if Self::contains_promql_exec(repartition_exec.input(), false) { return has_direct_promql_input; } @@ -143,45 +147,23 @@ impl ReduceAggregateRepartition { coarsening_satisfaction.is_subset() } - fn has_direct_promql_partial_input(plan: &Arc) -> bool { - let Some(partial_agg) = plan.as_any().downcast_ref::() else { - return false; - }; - - partial_agg.mode() == &AggregateMode::Partial - && Self::contains_promql_vector_exec(partial_agg.input()) - } - - fn contains_promql_vector_exec(plan: &Arc) -> bool { - if Self::is_promql_vector_exec(plan) { + fn contains_promql_exec(plan: &Arc, stop_at_aggregate: bool) -> bool { + let plan_any = plan.as_any(); + if plan_any.is::() + || plan_any.is::() + || plan_any.is::() + || plan_any.is::() + { return true; } - if plan.as_any().is::() { + if stop_at_aggregate && plan_any.is::() { return false; } plan.children() .into_iter() - .any(Self::contains_promql_vector_exec) - } - - fn contains_promql_exec_deep(plan: &Arc) -> bool { - if Self::is_promql_vector_exec(plan) { - return true; - } - - plan.children() - .into_iter() - .any(Self::contains_promql_exec_deep) - } - - fn is_promql_vector_exec(plan: &Arc) -> bool { - let plan = plan.as_any(); - plan.is::() - || plan.is::() - || plan.is::() - || plan.is::() + .any(|child| Self::contains_promql_exec(child, stop_at_aggregate)) } } @@ -205,15 +187,12 @@ mod tests { use super::ReduceAggregateRepartition; - fn schema() -> SchemaRef { - Arc::new(Schema::new(vec![ + fn input_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), Field::new("c", DataType::Int64, true), - ])) - } - - fn input_exec(schema: SchemaRef) -> Arc { + ])); let config = MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(); DataSourceExec::from_data_source(config) } @@ -241,13 +220,6 @@ mod tests { )?)) } - fn round_robin_repartition(input: Arc) -> Result> { - Ok(Arc::new(RepartitionExec::try_new( - input, - Partitioning::RoundRobinBatch(8), - )?)) - } - fn aggregate( mode: AggregateMode, input: Arc, @@ -284,7 +256,7 @@ mod tests { #[test] fn rewrites_final_partitioned_subset_repartition() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?; let partial = aggregate( AggregateMode::Partial, @@ -315,7 +287,7 @@ mod tests { #[test] fn rewrites_single_partitioned_subset_repartition() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?; let final_repartition = repartition(finer.clone(), &["a"], &finer.schema())?; let final_agg = aggregate( @@ -338,7 +310,7 @@ mod tests { #[test] fn keeps_equal_partitioning_keys() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?; let partial = aggregate( AggregateMode::Partial, @@ -369,7 +341,7 @@ mod tests { #[test] fn rewrites_when_finer_key_order_differs() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let finer = repartition(raw.clone(), &["c", "a", "b"], &raw.schema())?; let partial = aggregate( AggregateMode::Partial, @@ -400,7 +372,7 @@ mod tests { #[test] fn rewrites_when_repartition_satisfies_group_by_with_subset_keys() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let finer = repartition(raw.clone(), &["a", "b", "c"], &raw.schema())?; let final_repartition = repartition(finer.clone(), &["a"], &finer.schema())?; let final_agg = aggregate( @@ -423,7 +395,7 @@ mod tests { #[test] fn keeps_non_hash_repartition_child() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let finer = repartition(raw.clone(), &["a", "b"], &raw.schema())?; let partial = aggregate( AggregateMode::Partial, @@ -432,7 +404,10 @@ mod tests { raw.schema(), vec![], )?; - let final_repartition = round_robin_repartition(partial.clone())?; + let final_repartition = Arc::new(RepartitionExec::try_new( + partial.clone(), + Partitioning::RoundRobinBatch(8), + )?); let final_agg = aggregate( AggregateMode::FinalPartitioned, final_repartition, @@ -454,7 +429,7 @@ mod tests { #[test] fn rewrites_subset_partitioning_through_projection() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let finer = repartition(raw.clone(), &["a", "b", "c"], &raw.schema())?; let projected = project_with_aliases(finer, &[("a", "x"), ("b", "y"), ("c", "z")])?; let final_repartition = repartition(projected.clone(), &["x", "y"], &projected.schema())?; @@ -477,7 +452,7 @@ mod tests { #[test] fn keeps_non_subset_repartition() -> Result<()> { - let raw = input_exec(schema()); + let raw = input_exec(); let coarser = repartition(raw.clone(), &["a"], &raw.schema())?; let final_repartition = repartition(coarser.clone(), &["a", "b"], &coarser.schema())?; let final_agg = aggregate(