From 89b1a204bf1f277d009182fa3afd2fffe1cf56bd Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 20 May 2026 21:30:34 +0800 Subject: [PATCH] fix: rm global aggr Signed-off-by: discord9 --- src/flow/src/batching_mode/utils.rs | 420 ++++++---------------------- 1 file changed, 79 insertions(+), 341 deletions(-) diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index fdce5ccd09..f9737f54f5 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -52,8 +52,6 @@ use crate::df_optimizer::apply_df_optimizer; use crate::error::{DatafusionSnafu, ExternalSnafu, InvalidQuerySnafu, TableNotFoundSnafu}; use crate::{Error, TableName}; -const GLOBAL_AGGREGATE_JOIN_KEY: &str = "__flow_global_aggregate_join_key"; - /// Describes how one aggregate output field should be merged with the /// corresponding existing field in the sink table. /// @@ -433,17 +431,11 @@ pub fn analyze_incremental_aggregate_plan( "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(), ); } - unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned()); - if group_key_names.is_empty() - && projection_info - .output_field_names - .iter() - .any(|name| name == GLOBAL_AGGREGATE_JOIN_KEY) - { - unsupported_exprs.push(format!( - "unsupported output field uses reserved internal name: {GLOBAL_AGGREGATE_JOIN_KEY}" - )); + if group_key_names.is_empty() { + unsupported_exprs + .push("unsupported global aggregate in incremental aggregate rewrite".to_string()); } + unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned()); for aggr_expr in aggr_exprs { let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) { Ok(merge_op) => merge_op, @@ -513,11 +505,6 @@ pub fn analyze_incremental_aggregate_plan( /// LEFT JOIN sink /// ON __flow_delta.ts IS NOT DISTINCT FROM __flow_sink.ts /// ``` -/// -/// For a global aggregate without group keys, DataFusion still requires a -/// non-empty join condition. We add `__flow_global_aggregate_join_key = 1` to -/// both sides and join on it. This relies on the global aggregate sink keeping a -/// single state row; multiple sink rows would fan out the single delta row. pub async fn rewrite_incremental_aggregate_with_sink_merge( delta_plan: &LogicalPlan, analysis: &IncrementalAggregateAnalysis, @@ -543,9 +530,16 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( } ); + ensure!( + !analysis.group_key_names.is_empty(), + InvalidQuerySnafu { + reason: "UNSUPPORTED_INCREMENTAL_AGG: global aggregate query is not supported" + .to_string() + } + ); + let delta_alias = "__flow_delta"; let sink_alias = "__flow_sink"; - let is_global_aggregate = analysis.group_key_names.is_empty(); let mut selected_columns = analysis.group_key_names.clone(); selected_columns.extend( @@ -557,20 +551,11 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( let mut delta_selected_columns = selected_columns.clone(); delta_selected_columns.extend(analysis.literal_columns.iter().cloned()); - let mut delta_selected_exprs = delta_selected_columns + let delta_selected_exprs = delta_selected_columns .iter() .cloned() .map(unqualified_col) .collect::>(); - if is_global_aggregate { - // DataFusion does not allow an empty join condition. A global aggregate - // has exactly one delta row and its sink is expected to hold exactly one - // state row, so both sides use the same internal constant key to express - // "merge the single global state row" as a normal left join. If a sink - // somehow contains multiple rows, this join would fan out; callers must - // maintain the single-row sink invariant for global aggregate flows. - delta_selected_exprs.push(lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY)); - } let delta_selected = LogicalPlanBuilder::from(delta_plan.clone()) .project(delta_selected_exprs) .with_context(|_| DatafusionSnafu { @@ -604,14 +589,11 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( })?, ); - let mut sink_selected_exprs = selected_columns + let sink_selected_exprs = selected_columns .iter() .cloned() .map(unqualified_col) .collect::>(); - if is_global_aggregate { - sink_selected_exprs.push(lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY)); - } let sink_selected = LogicalPlanBuilder::from(sink_scan) .project(sink_selected_exprs) .with_context(|_| DatafusionSnafu { @@ -626,27 +608,20 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( context: "Failed to build projected sink plan for incremental sink merge".to_string(), })?; - let join_keys = if is_global_aggregate { - ( - vec![qualified_column(delta_alias, GLOBAL_AGGREGATE_JOIN_KEY)], - vec![qualified_column(sink_alias, GLOBAL_AGGREGATE_JOIN_KEY)], - ) - } else { - ( - analysis - .group_key_names - .iter() - .cloned() - .map(|c| qualified_column(delta_alias, c)) - .collect::>(), - analysis - .group_key_names - .iter() - .cloned() - .map(|c| qualified_column(sink_alias, c)) - .collect::>(), - ) - }; + let join_keys = ( + analysis + .group_key_names + .iter() + .cloned() + .map(|c| qualified_column(delta_alias, c)) + .collect::>(), + analysis + .group_key_names + .iter() + .cloned() + .map(|c| qualified_column(sink_alias, c)) + .collect::>(), + ); let joined = LogicalPlanBuilder::from(delta_selected) .join_detailed( @@ -2193,7 +2168,8 @@ mod test { async fn test_analyze_incremental_aggregate_plan_allows_string_literal_output() { let query_engine = create_test_query_engine(); let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS number, 'hello' AS label FROM numbers_with_ts"; + let sql = + "SELECT max(number) AS number, ts, 'hello' AS label FROM numbers_with_ts GROUP BY ts"; let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); @@ -2201,73 +2177,25 @@ mod test { assert_eq!(analysis.literal_columns, vec!["label".to_string()]); assert_eq!( analysis.output_field_names, - vec!["number".to_string(), "label".to_string()] + vec!["number".to_string(), "ts".to_string(), "label".to_string()] ); - - let sink_table = single_row_u32_table("string_literal_sink", vec!["number"]); - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "string_literal_sink".to_string(), - ]; - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table.clone(), - &sink_table_name, - ) - .await - .unwrap(); - - assert_eq!( - rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(), - vec!["number".to_string(), "label".to_string()] - ); - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![ - unqualified_col("number"), - unqualified_col("label"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - vec![ - unqualified_col("number"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - ( - vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], - vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], - ), - vec![ - max_merge_expr("number"), - qualified_col("__flow_delta", "label").alias("label"), - ], - ); - assert_same_logical_plan(&rewritten, &expected); } #[tokio::test] async fn test_rewrite_incremental_aggregate_preserves_non_identifier_aliases() { let query_engine = create_test_query_engine(); let ctx = QueryContext::arc(); - let sql = - "SELECT max(number) AS \"max value\", 42 AS \"literal value\" FROM numbers_with_ts"; + let sql = "SELECT max(number) AS \"max value\", number, 42 AS \"literal value\" FROM numbers_with_ts GROUP BY number"; let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); assert!(analysis.unsupported_exprs.is_empty()); assert_eq!( analysis.output_field_names, - vec!["max value", "literal value"] + vec!["max value", "number", "literal value"] ); - let sink_table = single_row_u32_table("non_identifier_alias_sink", vec!["max value"]); + let sink_table = + single_row_u32_table("non_identifier_alias_sink", vec!["number", "max value"]); let rewritten = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, @@ -2288,35 +2216,14 @@ mod test { .iter() .map(|field| field.name().clone()) .collect::>(), - vec!["max value".to_string(), "literal value".to_string()] + vec![ + "max value".to_string(), + "number".to_string(), + "literal value".to_string() + ] ); } - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_reserved_global_join_key_output() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let testcases = [ - format!("SELECT max(number) AS \"{GLOBAL_AGGREGATE_JOIN_KEY}\" FROM numbers_with_ts"), - format!("SELECT max(number) AS {GLOBAL_AGGREGATE_JOIN_KEY} FROM numbers_with_ts"), - ]; - - for sql in testcases { - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), &sql, false) - .await - .unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - analysis - .unsupported_exprs - .iter() - .any(|expr| expr.contains("reserved internal name")), - "global aggregate output should not collide with the internal join key for SQL {sql}: {:?}", - analysis.unsupported_exprs - ); - } - } - #[tokio::test] async fn test_analyze_incremental_aggregate_plan_rejects_uncovered_outputs() { let query_engine = create_test_query_engine(); @@ -2494,229 +2401,60 @@ mod test { } #[tokio::test] - async fn test_rewrite_incremental_aggregate_with_empty_join_keys_for_global_aggregate() { + async fn test_analyze_incremental_aggregate_plan_rejects_global_aggregate() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let testcases = [ + "SELECT max(number) AS number FROM numbers_with_ts", + "SELECT max(number) AS number, 42 AS lit FROM numbers_with_ts", + "SELECT count(*) AS cnt, sum(number) AS total FROM numbers_with_ts", + ]; + + for sql in testcases { + let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) + .await + .unwrap(); + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert_unsupported(&analysis, "global aggregate"); + } + } + + #[tokio::test] + async fn test_rewrite_incremental_aggregate_rejects_empty_group_keys() { let query_engine = create_test_query_engine(); let ctx = QueryContext::arc(); let sql = "SELECT max(number) AS number FROM numbers_with_ts"; let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert!(analysis.group_key_names.is_empty()); - assert_eq!(analysis.merge_columns.len(), 1); + let analysis = IncrementalAggregateAnalysis { + group_key_names: vec![], + merge_columns: vec![IncrementalAggregateMergeColumn::new( + "number".to_string(), + IncrementalAggregateMergeOp::Max, + )], + literal_columns: vec![], + output_field_names: vec!["number".to_string()], + unsupported_exprs: vec![], + }; - let sink_table = single_row_u32_table("global_sink", vec!["number"]); + let sink_table = single_row_u32_table("global_guard_sink", vec!["number"]); let sink_table_name = [ "greptime".to_string(), "public".to_string(), - "global_sink".to_string(), + "global_guard_sink".to_string(), ]; - let rewritten = rewrite_incremental_aggregate_with_sink_merge( + let err = rewrite_incremental_aggregate_with_sink_merge( &plan, &analysis, - sink_table.clone(), + sink_table, &sink_table_name, ) .await - .unwrap(); - - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![ - unqualified_col("number"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - vec![ - unqualified_col("number"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - ( - vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], - vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], - ), - vec![max_merge_expr("number")], + .unwrap_err(); + let err = format!("{err:?}"); + assert!( + err.contains("global aggregate query is not supported"), + "rewrite should defensively reject empty group keys: {err}" ); - assert_same_logical_plan(&rewritten, &expected); - assert_eq!( - rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(), - vec!["number".to_string()] - ); - } - - #[tokio::test] - async fn test_rewrite_incremental_aggregate_global_aggregate_with_empty_sink() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS number FROM numbers_with_ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - - let sink_table = empty_u32_table("empty_global_sink", vec!["number"]); - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "empty_global_sink".to_string(), - ]; - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table.clone(), - &sink_table_name, - ) - .await - .unwrap(); - - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![ - unqualified_col("number"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - vec![ - unqualified_col("number"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - ( - vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], - vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], - ), - vec![max_merge_expr("number")], - ); - assert_same_logical_plan(&rewritten, &expected); - assert_eq!( - rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(), - vec!["number".to_string()] - ); - } - - #[tokio::test] - async fn test_rewrite_incremental_aggregate_global_aggregate_with_literal() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS number, 42 AS lit FROM numbers_with_ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!(analysis.literal_columns, vec!["lit".to_string()]); - - let sink_table = single_row_u32_table("global_literal_sink", vec!["number"]); - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "global_literal_sink".to_string(), - ]; - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table.clone(), - &sink_table_name, - ) - .await - .unwrap(); - - assert_eq!( - rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(), - vec!["number".to_string(), "lit".to_string()] - ); - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![ - unqualified_col("number"), - unqualified_col("lit"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - vec![ - unqualified_col("number"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - ( - vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], - vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], - ), - vec![ - max_merge_expr("number"), - qualified_col("__flow_delta", "lit").alias("lit"), - ], - ); - assert_same_logical_plan(&rewritten, &expected); - } - - #[tokio::test] - async fn test_rewrite_incremental_aggregate_global_aggregate_with_multiple_merge_columns() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT count(*) AS cnt, sum(number) AS total FROM numbers_with_ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!(analysis.merge_columns.len(), 2); - - let sink_table = single_row_u32_table("global_multi_merge_sink", vec!["cnt", "total"]); - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "global_multi_merge_sink".to_string(), - ]; - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table.clone(), - &sink_table_name, - ) - .await - .unwrap(); - - assert_eq!( - rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(), - vec!["cnt".to_string(), "total".to_string()] - ); - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![ - unqualified_col("cnt"), - unqualified_col("total"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - vec![ - unqualified_col("cnt"), - unqualified_col("total"), - lit(1i32).alias(GLOBAL_AGGREGATE_JOIN_KEY), - ], - ( - vec![qualified_column("__flow_delta", GLOBAL_AGGREGATE_JOIN_KEY)], - vec![qualified_column("__flow_sink", GLOBAL_AGGREGATE_JOIN_KEY)], - ), - vec![sum_merge_expr("cnt"), sum_merge_expr("total")], - ); - assert_same_logical_plan(&rewritten, &expected); } #[tokio::test]