From 799196330af9a15f1a6d31e22c06ca9de7b9f1ff Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 20 May 2026 16:05:01 +0800 Subject: [PATCH] refactor: allow list Signed-off-by: discord9 --- src/flow/src/batching_mode/utils.rs | 254 ++++++++++++++++++++-------- 1 file changed, 179 insertions(+), 75 deletions(-) diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index 59041c2d31..56c22f5a07 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -111,29 +111,25 @@ pub struct IncrementalAggregateAnalysis { pub unsupported_exprs: Vec, } -/// Visitor that captures the aggregate expressions from the **innermost** -/// `Aggregate` node in the plan tree. +/// Visitor that captures aggregate expressions and counts `Aggregate` nodes in +/// the plan tree. /// -/// Since this visits `f_down` and continues recursion, it will overwrite -/// `aggr_exprs` for each `Aggregate` it encounters, ultimately retaining the -/// deepest (innermost) one. This is the intended behavior for the incremental -/// aggregate rewrite: nested aggregates are not supported, and the delta plan -/// produced by the flow engine places a single `Aggregate` at the bottom. -/// -/// If the plan contains multiple nested `Aggregate` nodes (a subquery with its -/// own aggregation), the innermost one is captured, which is conservative and -/// safe — it prevents the rewrite from incorrectly operating on the outer -/// aggregate. +/// Incremental aggregate rewrite only supports plans with exactly one aggregate +/// node. The count lets the analyzer reject nested/sibling aggregate plans +/// instead of accidentally rewriting against whichever aggregate was visited +/// last. #[derive(Default)] -struct LastAggregateExprFinder { +struct AggregateExprFinder { aggr_exprs: Option>, + aggregate_count: usize, } -impl TreeNodeVisitor<'_> for LastAggregateExprFinder { +impl TreeNodeVisitor<'_> for AggregateExprFinder { type Node = LogicalPlan; fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result { if let LogicalPlan::Aggregate(aggregate) = node { + self.aggregate_count += 1; self.aggr_exprs = Some(aggregate.aggr_expr.clone()); } Ok(TreeNodeRecursion::Continue) @@ -186,27 +182,51 @@ fn find_group_key_names(plan: &LogicalPlan) -> Result, Error> { Ok(group_key_names) } -fn find_aggregate_exprs(plan: &LogicalPlan) -> Result>, Error> { - let mut aggregate_finder = LastAggregateExprFinder::default(); +fn find_aggregate_exprs(plan: &LogicalPlan) -> Result<(usize, Option>), Error> { + let mut aggregate_finder = AggregateExprFinder::default(); plan.visit(&mut aggregate_finder) .with_context(|_| DatafusionSnafu { context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"), })?; - Ok(aggregate_finder.aggr_exprs) + Ok(( + aggregate_finder.aggregate_count, + aggregate_finder.aggr_exprs, + )) } -fn contains_aggregate(plan: &LogicalPlan) -> bool { - matches!(plan, LogicalPlan::Aggregate(_)) || plan.inputs().into_iter().any(contains_aggregate) -} +fn check_inc_aggr_plan_shape(plan: &LogicalPlan) -> Result<(), String> { + // Supported final shape: optional output Projection directly over one + // Aggregate. Post-aggregate filters (HAVING), ordering, limits, + // distinct/window/union/extension nodes are intentionally not accepted. + let plan = match plan { + LogicalPlan::Projection(projection) => projection.input.as_ref(), + _ => plan, + }; -fn has_filter_above_aggregate(plan: &LogicalPlan) -> bool { match plan { - // HAVING and other post-aggregate filters appear as `Filter` nodes above - // an `Aggregate`. Applying them before the sink-merge would filter on - // the delta aggregate rather than the final merged aggregate, so reject - // them until the rewrite can rebuild the predicate after merging. - LogicalPlan::Filter(filter) if contains_aggregate(filter.input.as_ref()) => true, - _ => plan.inputs().into_iter().any(has_filter_above_aggregate), + LogicalPlan::Aggregate(aggregate) => check_input_plan_shape(aggregate.input.as_ref()), + LogicalPlan::Filter(_) => Err( + "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite" + .to_string(), + ), + _ => Err( + "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(), + ), + } +} + +fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> { + match plan { + // Supported aggregate input shape: optional WHERE filter over a table scan. + LogicalPlan::TableScan(_) => Ok(()), + LogicalPlan::Filter(filter) + if matches!(filter.input.as_ref(), LogicalPlan::TableScan(_)) => + { + Ok(()) + } + _ => Err( + "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(), + ), } } @@ -303,22 +323,30 @@ fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo { projection_info } -fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Option { - let aggr_func = get_aggr_func(aggr_expr)?; +fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Result { + let Some(aggr_func) = get_aggr_func(aggr_expr) else { + return Err(aggr_expr.to_string()); + }; if aggr_func.params.distinct { - return None; + return Err(format!("unsupported DISTINCT aggregate: {aggr_expr}")); + } + if !aggr_func.params.order_by.is_empty() { + return Err(format!("unsupported aggregate ORDER BY: {aggr_expr}")); + } + if aggr_func.params.null_treatment.is_some() { + return Err(format!("unsupported aggregate NULL treatment: {aggr_expr}")); } match aggr_func.func.name().to_ascii_lowercase().as_str() { - "sum" | "count" => Some(IncrementalAggregateMergeOp::Sum), - "min" => Some(IncrementalAggregateMergeOp::Min), - "max" => Some(IncrementalAggregateMergeOp::Max), - "bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd), - "bool_or" => Some(IncrementalAggregateMergeOp::BoolOr), - "bit_and" => Some(IncrementalAggregateMergeOp::BitAnd), - "bit_or" => Some(IncrementalAggregateMergeOp::BitOr), - "bit_xor" => Some(IncrementalAggregateMergeOp::BitXor), - _ => None, + "sum" | "count" => Ok(IncrementalAggregateMergeOp::Sum), + "min" => Ok(IncrementalAggregateMergeOp::Min), + "max" => Ok(IncrementalAggregateMergeOp::Max), + "bool_and" => Ok(IncrementalAggregateMergeOp::BoolAnd), + "bool_or" => Ok(IncrementalAggregateMergeOp::BoolOr), + "bit_and" => Ok(IncrementalAggregateMergeOp::BitAnd), + "bit_or" => Ok(IncrementalAggregateMergeOp::BitOr), + "bit_xor" => Ok(IncrementalAggregateMergeOp::BitXor), + _ => Err(aggr_expr.to_string()), } } @@ -370,7 +398,7 @@ pub fn analyze_incremental_aggregate_plan( plan: &LogicalPlan, ) -> Result, Error> { let group_key_names = find_group_key_names(plan)?; - let Some(aggr_exprs) = find_aggregate_exprs(plan)? else { + let (aggregate_count, Some(aggr_exprs)) = find_aggregate_exprs(plan)? else { return Ok(None); }; let projection_info = collect_output_projection_info(plan); @@ -382,11 +410,13 @@ pub fn analyze_incremental_aggregate_plan( .into_iter() .map(|name| format!("duplicate output field name: {name}")) .collect::>(); - if has_filter_above_aggregate(plan) { - unsupported_exprs.push( - "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite" - .to_string(), - ); + if aggregate_count != 1 { + unsupported_exprs.push(format!( + "unsupported aggregate plan contains {aggregate_count} Aggregate nodes" + )); + } + if let Err(reason) = check_inc_aggr_plan_shape(plan) { + unsupported_exprs.push(reason); } unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned()); if group_key_names.is_empty() @@ -400,9 +430,12 @@ pub fn analyze_incremental_aggregate_plan( )); } for aggr_expr in aggr_exprs { - let Some(merge_op) = merge_op_for_aggregate_expr(&aggr_expr) else { - unsupported_exprs.push(aggr_expr.to_string()); - continue; + let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) { + Ok(merge_op) => merge_op, + Err(reason) => { + unsupported_exprs.push(reason); + continue; + } }; let Some(output_field_name) = resolve_aggregate_output_field_name( &aggr_expr, @@ -1402,6 +1435,28 @@ mod test { .alias(field_name) } + async fn analyze_test_sql(sql: &str) -> IncrementalAggregateAnalysis { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + analyze_incremental_aggregate_plan(&plan).unwrap().unwrap() + } + + fn assert_unsupported(analysis: &IncrementalAggregateAnalysis, reason: &str) { + assert!( + analysis + .unsupported_exprs + .iter() + .any(|expr| expr.contains(reason)), + "expected unsupported reason containing {reason:?}, got {:?}", + analysis.unsupported_exprs + ); + assert!( + analysis.merge_columns.is_empty(), + "unsupported analysis should disable merge columns" + ); + } + /// test if uppercase are handled correctly(with quote) #[tokio::test] async fn test_sql_plan_convert() { @@ -1808,26 +1863,83 @@ mod test { #[tokio::test] async fn test_analyze_incremental_aggregate_plan_rejects_having_filter() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts HAVING sum(number) > 10"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "post-aggregate filter"); + } - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - analysis - .unsupported_exprs - .iter() - .any(|expr| expr.contains("post-aggregate filter")), - "HAVING/post-aggregate filter should be unsupported: {:?}", - analysis.unsupported_exprs - ); - assert!( - analysis.merge_columns.is_empty(), - "unsupported HAVING should disable merge columns" + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_allows_aggregate_filter() { + let sql = "SELECT sum(number) FILTER (WHERE number > 10) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let analysis = analyze_test_sql(sql).await; + + assert!(analysis.unsupported_exprs.is_empty()); + assert!(analysis.group_key_names.contains(&"ts".to_string())); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!(analysis.merge_columns[0].output_field_name, "number"); + assert_eq!( + analysis.merge_columns[0].merge_op, + IncrementalAggregateMergeOp::Sum ); } + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_aggregate_order_by() { + let sql = "SELECT sum(number ORDER BY ts) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "aggregate ORDER BY"); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_sort_above_aggregate() { + let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts ORDER BY number DESC"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "post-aggregate plan shape"); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_limit_above_aggregate() { + let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts LIMIT 1"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "post-aggregate plan shape"); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_distinct_above_aggregate() { + let sql = "SELECT DISTINCT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "post-aggregate plan shape"); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_nested_aggregates() { + let sql = "SELECT sum(cnt) AS total FROM (SELECT count(*) AS cnt, ts FROM numbers_with_ts GROUP BY ts) s"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "Aggregate nodes"); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_union_aggregate_branches() { + let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts UNION ALL SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "Aggregate nodes"); + assert_unsupported(&analysis, "post-aggregate plan shape"); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_window_above_aggregate() { + let sql = "SELECT sum(number) AS number, ts, row_number() OVER (ORDER BY sum(number)) AS rn FROM numbers_with_ts GROUP BY ts"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "post-aggregate plan shape"); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_join_below_aggregate() { + let sql = "SELECT sum(lhs.number) AS number, lhs.ts FROM numbers_with_ts AS lhs JOIN numbers_with_ts AS rhs ON lhs.ts = rhs.ts GROUP BY lhs.ts"; + let analysis = analyze_test_sql(sql).await; + assert_unsupported(&analysis, "aggregate input plan shape"); + } + #[tokio::test] async fn test_analyze_incremental_aggregate_plan_preserves_raw_aggregate_name() { let query_engine = create_test_query_engine(); @@ -2627,28 +2739,20 @@ mod test { } #[tokio::test] - async fn test_last_aggregate_finder_captures_innermost() { + async fn test_aggregate_expr_finder_counts_multiple_aggregates() { let query_engine = create_test_query_engine(); let ctx = QueryContext::arc(); // Subquery has an inner aggregate (count), outer query has another aggregate (sum). - // LastAggregateExprFinder should capture the innermost one (count). let sql = "SELECT sum(cnt) AS total, ts \ FROM (SELECT ts, count(number) AS cnt FROM numbers_with_ts GROUP BY ts) AS sub \ GROUP BY ts"; let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let mut finder = LastAggregateExprFinder::default(); + let mut finder = AggregateExprFinder::default(); plan.visit(&mut finder).unwrap(); - let aggr_exprs = finder.aggr_exprs.unwrap(); - assert_eq!( - aggr_exprs.len(), - 1, - "Expected innermost aggregate to have 1 expression" - ); - let found_name = aggr_exprs[0].qualified_name().1.to_ascii_lowercase(); assert!( - found_name.contains("count"), - "Expected innermost aggregate to be count, got: {found_name}" + finder.aggregate_count > 1, + "nested aggregate plans should be identifiable as unsupported" ); } }