diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index 78ee05ad9a..073f9c6947 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -28,7 +28,7 @@ use datafusion::sql::unparser::Unparser; use datafusion_common::tree_node::{ Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{DFSchema, DataFusionError, ScalarValue, TableReference}; +use datafusion_common::{DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference}; use datafusion_expr::logical_plan::TableScan; use datafusion_expr::{ Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr, @@ -105,6 +105,23 @@ impl TreeNodeVisitor<'_> for LastAggregateExprFinder { } } +/// Recursively find all `Expr::Column` names inside an expression tree. +/// Only recurses into merge-transparent wrappers (`Alias`, `Cast`, `TryCast`). +/// Non-transparent wrappers (e.g., `ScalarFunction`, `Negative`) are +/// intentionally not recursed into since their merge semantics would be +/// incorrect — the caller records no alias mapping for such expressions. 
+fn find_column_names(expr: &Expr, names: &mut Vec) { + match expr { + Expr::Column(col) => { + names.push(col.name.clone()); + } + Expr::Alias(alias) => find_column_names(&alias.expr, names), + Expr::Cast(cast) => find_column_names(&cast.expr, names), + Expr::TryCast(try_cast) => find_column_names(&try_cast.expr, names), + _ => {} + } +} + pub fn analyze_incremental_aggregate_plan( plan: &LogicalPlan, ) -> Result, Error> { @@ -128,12 +145,26 @@ pub fn analyze_incremental_aggregate_plan( for expr in &projection.expr { match expr { Expr::Alias(alias) => { - if let Expr::Column(col) = alias.expr.as_ref() { - output_aliases.insert(col.name.clone(), alias.name.clone()); + // Alias resolution has three cases: + // - 0 Column refs (e.g., literal `42 AS lit`, or a non-transparent wrapper + // such as `COALESCE(sum(x), 0) AS v`): skip — no mapping + // - 1 Column ref: record the mapping (e.g., `CAST(sum(x) AS BIGINT) AS total`) + // - >1 Column refs: skip — ambiguous merge semantics, no mapping + let alias_name = alias.name.clone(); + let mut col_names = Vec::new(); + find_column_names(&alias.expr, &mut col_names); + if col_names.len() == 1 { + if let Some(col_name) = col_names.into_iter().next() { + output_aliases.entry(col_name).or_insert(alias_name); + } } + // Any other count of column references intentionally leaves the alias unmapped; + // an aggregate without a mapping is later reported in `unsupported_exprs`. 
} Expr::Column(col) => { - output_aliases.insert(col.name.clone(), col.name.clone()); + output_aliases + .entry(col.name.clone()) + .or_insert(col.name.clone()); } _ => {} } @@ -178,7 +209,10 @@ pub fn analyze_incremental_aggregate_plan( }; let raw_name = aggr_expr.qualified_name().1; - let output_field_name = output_aliases.get(&raw_name).cloned().unwrap_or(raw_name); + let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else { + unsupported_exprs.push(aggr_expr.to_string()); + continue; + }; merge_columns.push(IncrementalAggregateMergeColumn { output_field_name, merge_op, @@ -290,7 +324,13 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( ); let joined = LogicalPlanBuilder::from(delta_selected) - .join(sink_selected, JoinType::Left, join_keys, None) + .join_detailed( + sink_selected, + JoinType::Left, + join_keys, + None, + NullEquality::NullEqualsNull, + ) .with_context(|_| DatafusionSnafu { context: "Failed to left join delta and sink plans for incremental sink merge" .to_string(), @@ -1340,6 +1380,28 @@ mod test { assert!(!analysis.unsupported_exprs.is_empty()); } + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_coalesce_wrapped_aggregate() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + // COALESCE wraps the aggregate output — the wrapper is not merge-transparent, + // so the analyzer should mark the aggregate as unsupported rather than + // attempting an unsafe incremental rewrite. 
+ let sql = + "SELECT COALESCE(max(number), 0) AS coalesced_max, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + // Non-transparent wrapper → no alias mapping recorded → aggregate is unsupported + assert!( + !analysis.unsupported_exprs.is_empty(), + "COALESCE-wrapped aggregate should be unsupported" + ); + assert!( + analysis.merge_columns.is_empty(), + "COALESCE-wrapped aggregate should have no merge columns" + ); + } + #[tokio::test] async fn test_rewrite_incremental_aggregate_with_left_join() { let query_engine = create_test_query_engine(); @@ -1391,4 +1453,27 @@ mod test { .encode(&plan, DefaultSerializer) .unwrap(); } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_handles_cast_wrapped_alias() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + // CAST is a merge-transparent wrapper — the analyzer should still resolve the alias + let sql = + "SELECT CAST(sum(number) AS BIGINT) AS total, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + assert!(analysis.group_key_names.contains(&"ts".to_string())); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!( + analysis.merge_columns[0].output_field_name, "total", + "Expected alias 'total' for CAST-wrapped aggregate, but got '{}'", + analysis.merge_columns[0].output_field_name + ); + assert_eq!( + analysis.merge_columns[0].merge_op, + IncrementalAggregateMergeOp::Sum + ); + } }