From 90a119ceadea95703809d0b1b4a17db71bea6a41 Mon Sep 17 00:00:00 2001 From: discord9 Date: Tue, 19 May 2026 22:14:37 +0800 Subject: [PATCH] test: add expected plan test Signed-off-by: discord9 --- src/flow/src/batching_mode/utils.rs | 435 ++++++++++++++++++++++++---- 1 file changed, 371 insertions(+), 64 deletions(-) diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index 6ca69d9d8f..188815174d 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -199,6 +199,7 @@ fn find_aggregate_exprs(plan: &LogicalPlan) -> Result>, Error> struct OutputProjectionInfo { has_top_level_projection: bool, output_aliases: HashMap, + duplicate_aggregate_aliases: BTreeSet, literal_columns: HashSet, output_field_names: Vec, } @@ -251,7 +252,15 @@ fn collect_output_projection_info(plan: &LogicalPlan) -> OutputProjectionInfo { } 1 => { if let Some(col_name) = col_names.into_iter().next() { - output_aliases.entry(col_name).or_insert(alias_name); + if let Some(existing_alias) = output_aliases.get(&col_name) { + if existing_alias != &alias_name { + projection_info.duplicate_aggregate_aliases.insert(format!( + "same aggregate output {col_name} is used by multiple aliases: {existing_alias}, {alias_name}" + )); + } + } else { + output_aliases.insert(col_name, alias_name); + } } } _ => {} @@ -358,6 +367,7 @@ pub fn analyze_incremental_aggregate_plan( .into_iter() .map(|name| format!("duplicate output field name: {name}")) .collect::>(); + unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned()); if group_key_names.is_empty() && projection_info .output_field_names @@ -391,6 +401,9 @@ pub fn analyze_incremental_aggregate_plan( .into_iter() .map(|name| format!("unsupported output field: {name}")), ); + if !unsupported_exprs.is_empty() { + merge_columns.clear(); + } let mut literal_columns = projection_info .literal_columns .into_iter() @@ -406,6 +419,36 @@ pub fn 
analyze_incremental_aggregate_plan( })) } +/// Rewrites one incremental aggregate delta plan by left-joining it with the +/// existing sink-table state and projecting merged aggregate outputs. +/// +/// For a grouped aggregate such as: +/// +/// ```text +/// SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts +/// ``` +/// +/// the rewrite is roughly: +/// +/// ```text +/// delta = SELECT ts, number FROM <delta_plan> AS __flow_delta +/// sink = SELECT ts, number FROM <sink_table> AS __flow_sink +/// SELECT +/// CASE +/// WHEN __flow_sink.number IS NULL THEN __flow_delta.number +/// WHEN __flow_delta.number >= __flow_sink.number THEN __flow_delta.number +/// ELSE __flow_sink.number +/// END AS number, +/// __flow_delta.ts AS ts +/// FROM delta +/// LEFT JOIN sink +/// ON __flow_delta.ts IS NOT DISTINCT FROM __flow_sink.ts +/// ``` +/// +/// For a global aggregate without group keys, DataFusion still requires a +/// non-empty join condition. We add `__flow_global_aggregate_join_key = 1` to +/// both sides and join on it. This relies on the global aggregate sink keeping a +/// single state row; multiple sink rows would fan out the single delta row. 
pub async fn rewrite_incremental_aggregate_with_sink_merge( delta_plan: &LogicalPlan, analysis: &IncrementalAggregateAnalysis, @@ -1249,6 +1292,95 @@ mod test { u32_table(table_name, columns, 0) } + fn assert_same_logical_plan(actual: &LogicalPlan, expected: &LogicalPlan) { + assert_eq!( + format!("{}", expected.display_indent()), + format!("{}", actual.display_indent()) + ); + } + + fn test_sink_scan(sink_table: TableRef, sink_table_name: &TableName) -> LogicalPlan { + let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table)); + let table_source = Arc::new(DefaultTableSource::new(table_provider)); + LogicalPlan::TableScan( + TableScan::try_new( + TableReference::Full { + catalog: sink_table_name[0].clone().into(), + schema: sink_table_name[1].clone().into(), + table: sink_table_name[2].clone().into(), + }, + table_source, + None, + vec![], + None, + ) + .unwrap(), + ) + } + + fn expected_left_join_rewrite( + delta_plan: &LogicalPlan, + sink_table: TableRef, + sink_table_name: &TableName, + delta_selected_exprs: Vec, + sink_selected_exprs: Vec, + join_keys: (Vec, Vec), + projection_exprs: Vec, + ) -> LogicalPlan { + let delta_alias = "__flow_delta"; + let sink_alias = "__flow_sink"; + let delta_selected = LogicalPlanBuilder::from(delta_plan.clone()) + .project(delta_selected_exprs) + .unwrap() + .alias(delta_alias) + .unwrap() + .build() + .unwrap(); + let sink_selected = LogicalPlanBuilder::from(test_sink_scan(sink_table, sink_table_name)) + .project(sink_selected_exprs) + .unwrap() + .alias(sink_alias) + .unwrap() + .build() + .unwrap(); + let joined = LogicalPlanBuilder::from(delta_selected) + .join_detailed( + sink_selected, + JoinType::Left, + join_keys, + None, + NullEquality::NullEqualsNull, + ) + .unwrap() + .build() + .unwrap(); + LogicalPlanBuilder::from(joined) + .project(projection_exprs) + .unwrap() + .build() + .unwrap() + } + + fn max_merge_expr(field_name: &str) -> Expr { + let left = qualified_col("__flow_delta", field_name); + let 
right = qualified_col("__flow_sink", field_name); + when(is_null(right.clone()), left.clone()) + .when(left.clone().gt_eq(right.clone()), left) + .otherwise(right) + .unwrap() + .alias(field_name) + } + + fn sum_merge_expr(field_name: &str) -> Expr { + let left = qualified_col("__flow_delta", field_name); + let right = qualified_col("__flow_sink", field_name); + when(is_null(left.clone()), right.clone()) + .when(is_null(right.clone()), left.clone()) + .otherwise(binary_expr(left, Operator::Plus, right)) + .unwrap() + .alias(field_name) + } + /// test if uppercase are handled correctly(with quote) #[tokio::test] async fn test_sql_plan_convert() { @@ -1752,25 +1884,22 @@ mod test { vec!["number".to_string(), "ts".to_string(), "lit".to_string()] ); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ]; let (sink_table, _) = get_table_info_df_schema( query_engine.engine_state().catalog_manager().clone(), - [ - "greptime".to_string(), - "public".to_string(), - "numbers_with_ts".to_string(), - ], + sink_table_name.clone(), ) .await .unwrap(); let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - "numbers_with_ts".to_string(), - ], + sink_table.clone(), + &sink_table_name, ) .await .unwrap(); @@ -1782,6 +1911,27 @@ mod test { .map(|field| field.name().clone()) .collect::>(); assert_eq!(rewritten_fields, analysis.output_field_names); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![ + unqualified_col("ts"), + unqualified_col("number"), + unqualified_col("lit"), + ], + vec![unqualified_col("ts"), unqualified_col("number")], + ( + vec![qualified_column("__flow_delta", "ts")], + vec![qualified_column("__flow_sink", "ts")], + ), + vec![ + max_merge_expr("number"), + qualified_col("__flow_delta", "ts").alias("ts"), + qualified_col("__flow_delta", "lit").alias("lit"), + ], + ); + 
assert_same_logical_plan(&rewritten, &expected); } #[tokio::test] @@ -1813,6 +1963,54 @@ mod test { analysis.output_field_names, vec!["number".to_string(), "label".to_string()] ); + + let sink_table = single_row_u32_table("string_literal_sink", vec!["number"]); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "string_literal_sink".to_string(), + ]; + let rewritten = rewrite_incremental_aggregate_with_sink_merge( + &plan, + &analysis, + sink_table.clone(), + &sink_table_name, + ) + .await + .unwrap(); + + assert_eq!( + rewritten + .schema() + .fields() + .iter() + .map(|field| field.name().clone()) + .collect::>(), + vec!["number".to_string(), "label".to_string()] + ); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![ + unqualified_col("number"), + unqualified_col("label"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + vec![ + unqualified_col("number"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + ( + vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], + vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], + ), + vec![ + max_merge_expr("number"), + qualified_col("__flow_delta", "label").alias("label"), + ], + ); + assert_same_logical_plan(&rewritten, &expected); } #[tokio::test] @@ -1924,10 +2122,13 @@ mod test { analysis .unsupported_exprs .iter() - .any(|expr| expr.contains("unsupported output field: b")), + .any(|expr| expr.contains("same aggregate output") + && expr.contains("a") + && expr.contains("b")), "same aggregate with multiple aliases should be unsupported until explicit reproduction is implemented: {:?}", analysis.unsupported_exprs ); + assert!(analysis.merge_columns.is_empty()); } #[test] @@ -2013,13 +2214,14 @@ mod test { .await .unwrap(); let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "numbers_with_ts".to_string(), + ]; let 
(sink_table, _) = get_table_info_df_schema( query_engine.engine_state().catalog_manager().clone(), - [ - "greptime".to_string(), - "public".to_string(), - "numbers_with_ts".to_string(), - ], + sink_table_name.clone(), ) .await .unwrap(); @@ -2027,19 +2229,28 @@ mod test { let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - "numbers_with_ts".to_string(), - ], + sink_table.clone(), + &sink_table_name, ) .await .unwrap(); - let plan_text = format!("{}", rewritten.display_indent()); - assert!(plan_text.contains("Left Join")); - assert!(!plan_text.contains("Union")); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![unqualified_col("ts"), unqualified_col("number")], + vec![unqualified_col("ts"), unqualified_col("number")], + ( + vec![qualified_column("__flow_delta", "ts")], + vec![qualified_column("__flow_sink", "ts")], + ), + vec![ + max_merge_expr("number"), + qualified_col("__flow_delta", "ts").alias("ts"), + ], + ); + assert_same_logical_plan(&rewritten, &expected); } #[tokio::test] @@ -2054,21 +2265,39 @@ mod test { assert_eq!(analysis.merge_columns.len(), 1); let sink_table = single_row_u32_table("global_sink", vec!["number"]); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "global_sink".to_string(), + ]; let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - "global_sink".to_string(), - ], + sink_table.clone(), + &sink_table_name, ) .await .unwrap(); - let plan_text = format!("{}", rewritten.display_indent()); - assert!(plan_text.contains("Left Join")); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![ + unqualified_col("number"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + vec![ + unqualified_col("number"), + 
lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + ( + vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], + vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], + ), + vec![max_merge_expr("number")], + ); + assert_same_logical_plan(&rewritten, &expected); assert_eq!( rewritten .schema() @@ -2090,22 +2319,39 @@ mod test { assert!(analysis.unsupported_exprs.is_empty()); let sink_table = empty_u32_table("empty_global_sink", vec!["number"]); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "empty_global_sink".to_string(), + ]; let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - "empty_global_sink".to_string(), - ], + sink_table.clone(), + &sink_table_name, ) .await .unwrap(); - let plan_text = format!("{}", rewritten.display_indent()); - assert!(plan_text.contains("Left Join")); - assert!(plan_text.contains(GLOBAL_AGGREGATE_JOIN_KEY)); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![ + unqualified_col("number"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + vec![ + unqualified_col("number"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + ( + vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], + vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], + ), + vec![max_merge_expr("number")], + ); + assert_same_logical_plan(&rewritten, &expected); assert_eq!( rewritten .schema() @@ -2128,15 +2374,16 @@ mod test { assert_eq!(analysis.literal_columns, vec!["lit".to_string()]); let sink_table = single_row_u32_table("global_literal_sink", vec!["number"]); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "global_literal_sink".to_string(), + ]; let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - 
"global_literal_sink".to_string(), - ], + sink_table.clone(), + &sink_table_name, ) .await .unwrap(); @@ -2150,6 +2397,29 @@ mod test { .collect::>(), vec!["number".to_string(), "lit".to_string()] ); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![ + unqualified_col("number"), + unqualified_col("lit"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + vec![ + unqualified_col("number"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + ( + vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], + vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], + ), + vec![ + max_merge_expr("number"), + qualified_col("__flow_delta", "lit").alias("lit"), + ], + ); + assert_same_logical_plan(&rewritten, &expected); } #[tokio::test] @@ -2163,15 +2433,16 @@ mod test { assert_eq!(analysis.merge_columns.len(), 2); let sink_table = single_row_u32_table("global_multi_merge_sink", vec!["cnt", "total"]); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "global_multi_merge_sink".to_string(), + ]; let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - "global_multi_merge_sink".to_string(), - ], + sink_table.clone(), + &sink_table_name, ) .await .unwrap(); @@ -2185,6 +2456,27 @@ mod test { .collect::>(), vec!["cnt".to_string(), "total".to_string()] ); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![ + unqualified_col("cnt"), + unqualified_col("total"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + vec![ + unqualified_col("cnt"), + unqualified_col("total"), + lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), + ], + ( + vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], + vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], + ), + vec![sum_merge_expr("cnt"), sum_merge_expr("total")], + ); + 
assert_same_logical_plan(&rewritten, &expected); } #[tokio::test] @@ -2198,15 +2490,16 @@ mod test { let raw_field_name = "max(numbers_with_ts.number)"; let sink_table = single_row_u32_table("raw_aggregate_sink", vec!["number", raw_field_name]); + let sink_table_name = [ + "greptime".to_string(), + "public".to_string(), + "raw_aggregate_sink".to_string(), + ]; let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - "raw_aggregate_sink".to_string(), - ], + sink_table.clone(), + &sink_table_name, ) .await .unwrap(); @@ -2218,8 +2511,22 @@ mod test { .map(|field| field.name().clone()) .collect::>(); assert!(rewritten_fields.contains(&raw_field_name.to_string())); - let plan_text = format!("{}", rewritten.display_indent()); - assert!(plan_text.contains(raw_field_name)); + let expected = expected_left_join_rewrite( + &plan, + sink_table, + &sink_table_name, + vec![unqualified_col("number"), unqualified_col(raw_field_name)], + vec![unqualified_col("number"), unqualified_col(raw_field_name)], + ( + vec![qualified_column("__flow_delta", "number")], + vec![qualified_column("__flow_sink", "number")], + ), + vec![ + max_merge_expr(raw_field_name), + qualified_col("__flow_delta", "number").alias("number"), + ], + ); + assert_same_logical_plan(&rewritten, &expected); } #[tokio::test]