diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index 0c17dd75d1..412060decc 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -28,11 +28,13 @@ use datafusion::sql::unparser::Unparser; use datafusion_common::tree_node::{ Transformed, TreeNode as _, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference}; +use datafusion_common::{ + Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference, +}; use datafusion_expr::logical_plan::TableScan; use datafusion_expr::{ Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr, - bitwise_and, bitwise_or, bitwise_xor, col, is_null, or, when, + bitwise_and, bitwise_or, bitwise_xor, is_null, or, when, }; use datatypes::schema::{ColumnSchema, SchemaRef}; use query::QueryEngineRef; @@ -55,24 +57,20 @@ use crate::{Error, TableName}; /// /// `output_field_name` is the final output/sink schema field name produced by /// the delta plan and read from the sink table. It is not a DataFusion `Column` -/// reference and must not include a plan/table qualifier. +/// reference. It may contain dots or other non-identifier characters when the +/// query keeps DataFusion's raw aggregate output name, e.g. +/// `max(numbers_with_ts.number)`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct IncrementalAggregateMergeColumn { /// Final output/sink field name for the aggregate result/state column. /// - /// Must NOT include a plan/table qualifier (no `.` separator). pub output_field_name: String, pub merge_op: IncrementalAggregateMergeOp, } impl IncrementalAggregateMergeColumn { - /// Create a new merge column, validating that `output_field_name` does not - /// contain a plan/table qualifier. + /// Create a new merge column. pub fn new(output_field_name: String, merge_op: IncrementalAggregateMergeOp) -> Self { - debug_assert!( - !output_field_name.contains('.'), - "output_field_name must not include a plan/table qualifier, got: {output_field_name}" - ); Self { output_field_name, merge_op, @@ -97,7 +95,8 @@ pub enum IncrementalAggregateMergeOp { /// `group_key_names` and each merge column's `output_field_name` are final /// output/sink schema field names used to project both the delta plan and the /// sink table before the left-join merge. They are not DataFusion logical-plan -/// `Column` references and must not be qualified. +/// `Column` references; callers must attach qualifiers structurally instead of +/// formatting qualified names as strings. #[derive(Debug, Clone, PartialEq, Eq)] pub struct IncrementalAggregateAnalysis { /// Final output/sink field names for group keys used as merge join keys. @@ -152,25 +151,46 @@ fn find_column_names(expr: &Expr, names: &mut Vec) { } } -pub fn analyze_incremental_aggregate_plan( - plan: &LogicalPlan, -) -> Result, Error> { +fn unqualified_col(name: impl Into) -> Expr { + Expr::Column(Column::from_name(name.into())) +} + +fn qualified_col(qualifier: &str, name: impl Into) -> Expr { + Expr::Column(Column::new(Some(qualifier), name.into())) +} + +fn qualified_column(qualifier: &str, name: impl Into) -> Column { + Column::new(Some(qualifier), name.into()) +} + +fn find_group_key_names(plan: &LogicalPlan) -> Result, Error> { let mut group_finder = FindGroupByFinalName::default(); plan.visit(&mut group_finder) .with_context(|_| DatafusionSnafu { context: format!("Failed to inspect group-by columns from logical plan: {plan:?}"), })?; + let mut group_key_names = group_finder + .get_group_expr_names() + .unwrap_or_default() + .into_iter() + .collect::>(); + group_key_names.sort(); + Ok(group_key_names) +} + +fn find_aggregate_exprs(plan: &LogicalPlan) -> Result>, Error> { let mut aggregate_finder = LastAggregateExprFinder::default(); plan.visit(&mut aggregate_finder) .with_context(|_| DatafusionSnafu { context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"), })?; - let Some(aggr_exprs) = aggregate_finder.aggr_exprs else { - return Ok(None); - }; + Ok(aggregate_finder.aggr_exprs) +} +fn collect_output_aliases(plan: &LogicalPlan) -> (bool, HashMap, HashSet) { let mut output_aliases = HashMap::new(); + let has_top_level_projection = matches!(plan, LogicalPlan::Projection(_)); if let LogicalPlan::Projection(projection) = plan { for expr in &projection.expr { match expr { @@ -202,50 +222,79 @@ pub fn analyze_incremental_aggregate_plan( } } - let mut group_key_names = group_finder - .get_group_expr_names() - .unwrap_or_default() - .into_iter() - .collect::>(); - group_key_names.sort(); + let output_field_names = plan + .schema() + .fields() + .iter() + .map(|field| field.name().clone()) + .collect::>(); + + (has_top_level_projection, output_aliases, output_field_names) +} + +fn merge_op_for_aggregate_expr(aggr_expr: &Expr) -> Option { + let aggr_func = get_aggr_func(aggr_expr)?; + if aggr_func.params.distinct { + return None; + } + + match aggr_func.func.name().to_ascii_lowercase().as_str() { + "sum" | "count" => Some(IncrementalAggregateMergeOp::Sum), + "min" => Some(IncrementalAggregateMergeOp::Min), + "max" => Some(IncrementalAggregateMergeOp::Max), + "bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd), + "bool_or" => Some(IncrementalAggregateMergeOp::BoolOr), + "bit_and" => Some(IncrementalAggregateMergeOp::BitAnd), + "bit_or" => Some(IncrementalAggregateMergeOp::BitOr), + "bit_xor" => Some(IncrementalAggregateMergeOp::BitXor), + _ => None, + } +} + +fn resolve_aggregate_output_field_name( + aggr_expr: &Expr, + has_top_level_projection: bool, + output_aliases: &HashMap, + output_field_names: &HashSet, +) -> Option { + // qualified_name() returns (Option, String) where the second + // element is the unqualified column/alias name. This relies on + // DataFusion's internal naming convention: aggregate expressions + // emit a column named after the aggregate itself (e.g. "SUM(x)"), + // which matches what the projection aliases reference. + let raw_name = aggr_expr.qualified_name().1; + if let Some(alias) = output_aliases.get(&raw_name) { + Some(alias.clone()) + } else if !has_top_level_projection && output_field_names.contains(&raw_name) { + Some(raw_name) + } else { + None + } +} + +pub fn analyze_incremental_aggregate_plan( + plan: &LogicalPlan, +) -> Result, Error> { + let group_key_names = find_group_key_names(plan)?; + let Some(aggr_exprs) = find_aggregate_exprs(plan)? else { + return Ok(None); + }; + let (has_top_level_projection, output_aliases, output_field_names) = + collect_output_aliases(plan); let mut merge_columns = Vec::with_capacity(aggr_exprs.len()); let mut unsupported_exprs = Vec::new(); for aggr_expr in aggr_exprs { - let Some(aggr_func) = get_aggr_func(&aggr_expr) else { + let Some(merge_op) = merge_op_for_aggregate_expr(&aggr_expr) else { unsupported_exprs.push(aggr_expr.to_string()); continue; }; - - let aggr_name = aggr_func.func.name().to_ascii_lowercase(); - let merge_op = if aggr_func.params.distinct { - None - } else { - match aggr_name.as_str() { - "sum" | "count" => Some(IncrementalAggregateMergeOp::Sum), - "min" => Some(IncrementalAggregateMergeOp::Min), - "max" => Some(IncrementalAggregateMergeOp::Max), - "bool_and" => Some(IncrementalAggregateMergeOp::BoolAnd), - "bool_or" => Some(IncrementalAggregateMergeOp::BoolOr), - "bit_and" => Some(IncrementalAggregateMergeOp::BitAnd), - "bit_or" => Some(IncrementalAggregateMergeOp::BitOr), - "bit_xor" => Some(IncrementalAggregateMergeOp::BitXor), - _ => None, - } - }; - - let Some(merge_op) = merge_op else { - unsupported_exprs.push(aggr_expr.to_string()); - continue; - }; - - // qualified_name() returns (Option, String) where the second - // element is the unqualified column/alias name. This relies on - // DataFusion's internal naming convention: aggregate expressions - // emit a column named after the aggregate itself (e.g. "SUM(x)"), - // which matches what the projection aliases reference. - let raw_name = aggr_expr.qualified_name().1; - let Some(output_field_name) = output_aliases.get(&raw_name).cloned() else { + let Some(output_field_name) = resolve_aggregate_output_field_name( + &aggr_expr, + has_top_level_projection, + &output_aliases, + &output_field_names, + ) else { unsupported_exprs.push(aggr_expr.to_string()); continue; }; @@ -298,7 +347,11 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( .map(|c| c.output_field_name.clone()), ); - let selected_exprs = selected_columns.iter().map(col).collect::>(); + let selected_exprs = selected_columns + .iter() + .cloned() + .map(unqualified_col) + .collect::>(); let delta_selected = LogicalPlanBuilder::from(delta_plan.clone()) .project(selected_exprs.clone()) .with_context(|_| DatafusionSnafu { @@ -350,12 +403,14 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( analysis .group_key_names .iter() - .map(|c| datafusion_common::Column::from_qualified_name(format!("{delta_alias}.{c}"))) + .cloned() + .map(|c| qualified_column(delta_alias, c)) .collect::>(), analysis .group_key_names .iter() - .map(|c| datafusion_common::Column::from_qualified_name(format!("{sink_alias}.{c}"))) + .cloned() + .map(|c| qualified_column(sink_alias, c)) .collect::>(), ); @@ -379,7 +434,8 @@ pub async fn rewrite_incremental_aggregate_with_sink_merge( let mut projection_exprs = analysis .group_key_names .iter() - .map(|c| col(format!("{delta_alias}.{c}")).alias(c.clone())) + .cloned() + .map(|c| qualified_col(delta_alias, c.clone()).alias(c)) .collect::>(); for merge_col in &analysis.merge_columns { projection_exprs.push(build_left_join_merge_expr( @@ -405,8 +461,8 @@ fn build_left_join_merge_expr( sink_alias: &str, merge_col: &IncrementalAggregateMergeColumn, ) -> Result { - let left = col(format!("{delta_alias}.{}", merge_col.output_field_name)); - let right = col(format!("{sink_alias}.{}", merge_col.output_field_name)); + let left = qualified_col(delta_alias, merge_col.output_field_name.clone()); + let right = qualified_col(sink_alias, merge_col.output_field_name.clone()); let merged = match merge_col.merge_op { IncrementalAggregateMergeOp::Sum => when(is_null(left.clone()), right.clone()) .when(is_null(right.clone()), left.clone()) @@ -1414,6 +1470,117 @@ mod test { })); } + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_preserves_raw_aggregate_name() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT max(number), ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!( + analysis.merge_columns[0].output_field_name, + "max(numbers_with_ts.number)" + ); + assert_eq!( + analysis.merge_columns[0].merge_op, + IncrementalAggregateMergeOp::Max + ); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_wrapper_aliased_as_raw_name() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = r#"SELECT COALESCE(max(number), 0) AS "max(numbers_with_ts.number)", ts FROM numbers_with_ts GROUP BY ts"#; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!( + !analysis.unsupported_exprs.is_empty(), + "wrapper aliased to a raw aggregate field name must not bypass analysis" + ); + assert!(analysis.merge_columns.is_empty()); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_supports_count_star() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT count(*) AS wildcard, ts FROM numbers_with_ts GROUP BY ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!(analysis.merge_columns[0].output_field_name, "wildcard"); + assert_eq!( + analysis.merge_columns[0].merge_op, + IncrementalAggregateMergeOp::Sum + ); + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_supports_aggregate_input_exprs() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let testcases = [ + "SELECT sum(abs(number)) AS sum_abs, ts FROM numbers_with_ts GROUP BY ts", + "SELECT sum(CASE WHEN number > 5 THEN 1 ELSE 0 END) AS above_five, ts FROM numbers_with_ts GROUP BY ts", + ]; + + for sql in testcases { + let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) + .await + .unwrap(); + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!( + analysis.unsupported_exprs.is_empty(), + "aggregate input expressions should be mergeable for SQL: {sql}" + ); + assert_eq!(analysis.merge_columns.len(), 1); + assert_eq!( + analysis.merge_columns[0].merge_op, + IncrementalAggregateMergeOp::Sum + ); + } + } + + #[tokio::test] + async fn test_analyze_incremental_aggregate_plan_rejects_output_expr_wrappers() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let testcases = [ + "SELECT abs(sum(number)) AS abs_sum, ts FROM numbers_with_ts GROUP BY ts", + "SELECT max(number) - min(number) AS maxmin, ts FROM numbers_with_ts GROUP BY ts", + "SELECT count(number) + 123 AS total_count, ts FROM numbers_with_ts GROUP BY ts", + "SELECT sum(CASE WHEN number > 5 THEN 1 ELSE 0 END) / count(number) AS ratio, ts FROM numbers_with_ts GROUP BY ts", + ]; + + for sql in testcases { + let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) + .await + .unwrap(); + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!( + !analysis.unsupported_exprs.is_empty(), + "aggregate output wrappers should be rejected for SQL: {sql}" + ); + } + } + + #[test] + fn test_qualified_col_preserves_non_identifier_field_name() { + let expr = qualified_col("__flow_delta", "max(numbers_with_ts.number)"); + let Expr::Column(column) = expr else { + panic!("expected column expression"); + }; + assert_eq!(column.name, "max(numbers_with_ts.number)"); + assert_eq!(column.relation.unwrap().to_string(), "__flow_delta"); + } + #[tokio::test] async fn test_analyze_incremental_aggregate_plan_multiple_group_keys() { let query_engine = create_test_query_engine();