From 0251113fc7814420e75eb14dc23dd964bec3b684 Mon Sep 17 00:00:00 2001 From: discord9 Date: Fri, 15 May 2026 18:14:55 +0800 Subject: [PATCH] per review Signed-off-by: discord9 --- src/flow/src/batching_mode/utils.rs | 111 ++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 5 deletions(-) diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index d4e0c07f45..0c17dd75d1 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -59,10 +59,27 @@ use crate::{Error, TableName}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct IncrementalAggregateMergeColumn { /// Final output/sink field name for the aggregate result/state column. + /// + /// Must NOT include a plan/table qualifier (no `.` separator). pub output_field_name: String, pub merge_op: IncrementalAggregateMergeOp, } +impl IncrementalAggregateMergeColumn { + /// Create a new merge column, validating that `output_field_name` does not + /// contain a plan/table qualifier. + pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self { + debug_assert!( + !output_field_name.contains('.'), + "output_field_name must not include a plan/table qualifier, got: {output_field_name}" + ); + Self { + output_field_name, + merge_op, + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum IncrementalAggregateMergeOp { Sum, @@ -89,6 +106,19 @@ pub struct IncrementalAggregateAnalysis { pub unsupported_exprs: Vec, } +/// Visitor that captures the aggregate expressions from the **innermost** +/// `Aggregate` node in the plan tree. +/// +/// Since this visits `f_down` and continues recursion, it will overwrite +/// `aggr_exprs` for each `Aggregate` it encounters, ultimately retaining the +/// deepest (innermost) one. This is the intended behavior for the incremental +/// aggregate rewrite: nested aggregates are not supported, and the delta plan +/// produced by the flow engine places a single `Aggregate` at the bottom. +/// +/// If the plan contains multiple nested `Aggregate` nodes (a subquery with its +/// own aggregation), the innermost one is captured, which is conservative and +/// safe — it prevents the rewrite from incorrectly operating on the outer +/// aggregate. #[derive(Default)] struct LastAggregateExprFinder { aggr_exprs: Option>, @@ -209,15 +239,20 @@ pub fn analyze_incremental_aggregate_plan( continue; }; + // qualified_name() returns (Option, String) where the second + // element is the unqualified column/alias name. This relies on + // DataFusion's internal naming convention: aggregate expressions + // emit a column named after the aggregate itself (e.g. "SUM(x)"), + // which matches what the projection aliases reference. let raw_name = aggr_expr.qualified_name().1; let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else { unsupported_exprs.push(aggr_expr.to_string()); continue; }; - merge_columns.push(IncrementalAggregateMergeColumn { + merge_columns.push(IncrementalAggregateMergeColumn::new( output_field_name, merge_op, - }); + )); } Ok(Some(IncrementalAggregateAnalysis { @@ -1293,38 +1328,55 @@ mod test { async fn test_analyze_incremental_aggregate_plan() { let query_engine = create_test_query_engine(); let ctx = QueryContext::arc(); - let testcases = vec![ + let testcases: Vec<(&str, IncrementalAggregateMergeOp, &str)> = vec![ ( "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts", IncrementalAggregateMergeOp::Sum, + "number", ), ( "SELECT count(number) AS number, ts FROM numbers_with_ts GROUP BY ts", IncrementalAggregateMergeOp::Sum, + "number", ), ( "SELECT min(number) AS number, ts FROM numbers_with_ts GROUP BY ts", IncrementalAggregateMergeOp::Min, + "number", ), ( "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts", IncrementalAggregateMergeOp::Max, + "number", ), ( "SELECT bit_and(number) AS number, ts FROM numbers_with_ts GROUP BY ts", IncrementalAggregateMergeOp::BitAnd, + "number", ), ( "SELECT bit_or(number) AS number, ts FROM numbers_with_ts GROUP BY ts", IncrementalAggregateMergeOp::BitOr, + "number", ), ( "SELECT bit_xor(number) AS number, ts FROM numbers_with_ts GROUP BY ts", IncrementalAggregateMergeOp::BitXor, + "number", + ), + ( + "SELECT bool_and(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::BoolAnd, + "cond", + ), + ( + "SELECT bool_or(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts", + IncrementalAggregateMergeOp::BoolOr, + "cond", ), ]; - for (sql, expected_op) in testcases { + for (sql, expected_op, expected_field_name) in testcases { let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) .await .unwrap(); @@ -1333,7 +1385,10 @@ mod test { assert!(analysis.unsupported_exprs.is_empty()); assert!(analysis.group_key_names.contains(&"ts".to_string())); assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!(analysis.merge_columns[0].output_field_name, "number"); + assert_eq!( + analysis.merge_columns[0].output_field_name, + expected_field_name + ); assert_eq!(analysis.merge_columns[0].merge_op, expected_op); } } @@ -1359,6 +1414,26 @@ mod test { })); } + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_multiple_group_keys() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT sum(number) AS total, ts, number AS bucket FROM numbers_with_ts GROUP BY ts, number"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + assert!(analysis.group_key_names.contains(&"ts".to_string())); + assert!(analysis.group_key_names.contains(&"bucket".to_string())); + assert_eq!(analysis.group_key_names.len(), 2); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!(analysis.merge_columns[0].output_field_name, "total"); + assert_eq!( + analysis.merge_columns[0].merge_op, + IncrementalAggregateMergeOp::Sum + ); + } + #[tokio::test] async fn test_analyze_incremental_aggregate_plan_rejects_avg() { let query_engine = create_test_query_engine(); @@ -1477,4 +1552,30 @@ mod test { IncrementalAggregateMergeOp::Sum ); } + + #[tokio::test] + async fn test_last_aggregate_finder_captures_innermost() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + // Subquery has an inner aggregate (count), outer query has another aggregate (sum). + // LastAggregateExprFinder should capture the innermost one (count). + let sql = "SELECT sum(cnt) AS total, ts \ + FROM (SELECT ts, count(number) AS cnt FROM numbers_with_ts GROUP BY ts) AS sub \ + GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let mut finder = LastAggregateExprFinder::default(); + plan.visit(&mut finder).unwrap(); + let aggr_exprs = finder.aggr_exprs.unwrap(); + assert_eq!( + aggr_exprs.len(), + 1, + "Expected innermost aggregate to have 1 expression" + ); + let found_name = aggr_exprs[0].qualified_name().1.to_ascii_lowercase(); + assert!( + found_name.contains("count"), + "Expected innermost aggregate to be count, got: {found_name}" + ); + } }