From 3d8dba40cbd3df8a6fe23e992cfd389411c6ddd7 Mon Sep 17 00:00:00 2001 From: discord9 Date: Thu, 21 May 2026 12:42:18 +0800 Subject: [PATCH] fix: computed shadow expr Signed-off-by: discord9 --- src/flow/src/batching_mode/utils.rs | 1420 ++-------------------- src/flow/src/batching_mode/utils/test.rs | 43 + 2 files changed, 141 insertions(+), 1322 deletions(-) diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs index 3c1d89e2cb..7b066388ec 100644 --- a/src/flow/src/batching_mode/utils.rs +++ b/src/flow/src/batching_mode/utils.rs @@ -31,10 +31,10 @@ use datafusion_common::tree_node::{ use datafusion_common::{ Column, DFSchema, DataFusionError, NullEquality, ScalarValue, TableReference, }; -use datafusion_expr::logical_plan::TableScan; +use datafusion_expr::logical_plan::{Aggregate, TableScan}; use datafusion_expr::{ Distinct, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, Projection, and, binary_expr, - bitwise_and, bitwise_or, bitwise_xor, is_null, lit, or, when, + bitwise_and, bitwise_or, bitwise_xor, is_null, or, when, }; use datatypes::schema::{ColumnSchema, SchemaRef}; use query::QueryEngineRef; @@ -112,31 +112,6 @@ pub struct IncrementalAggregateAnalysis { pub unsupported_exprs: Vec, } -/// Visitor that captures aggregate expressions and counts `Aggregate` nodes in -/// the plan tree. -/// -/// Incremental aggregate rewrite only supports plans with exactly one aggregate -/// node. The count lets the analyzer reject nested/sibling aggregate plans -/// instead of accidentally rewriting against whichever aggregate was visited -/// last. -#[derive(Default)] -struct AggregateExprFinder { - aggr_exprs: Option>, - aggregate_count: usize, -} - -impl TreeNodeVisitor<'_> for AggregateExprFinder { - type Node = LogicalPlan; - - fn f_down(&mut self, node: &Self::Node) -> datafusion_common::Result { - if let LogicalPlan::Aggregate(aggregate) = node { - self.aggregate_count += 1; - self.aggr_exprs = Some(aggregate.aggr_expr.clone()); - } - Ok(TreeNodeRecursion::Continue) - } -} - /// Recursively find all `Expr::Column` names inside an expression tree. /// Only recurses into wrappers that are merge-transparent. /// Non-transparent wrappers (e.g., `ScalarFunction`, `Negative`, `Cast`) are @@ -183,18 +158,6 @@ fn find_group_key_names(plan: &LogicalPlan) -> Result, Error> { Ok(group_key_names) } -fn find_aggregate_exprs(plan: &LogicalPlan) -> Result<(usize, Option>), Error> { - let mut aggregate_finder = AggregateExprFinder::default(); - plan.visit(&mut aggregate_finder) - .with_context(|_| DatafusionSnafu { - context: format!("Failed to inspect aggregate expressions from logical plan: {plan:?}"), - })?; - Ok(( - aggregate_finder.aggregate_count, - aggregate_finder.aggr_exprs, - )) -} - fn has_grouping_set(plan: &LogicalPlan) -> bool { match plan { LogicalPlan::Aggregate(aggregate) => aggregate @@ -205,7 +168,21 @@ fn has_grouping_set(plan: &LogicalPlan) -> bool { } } -fn check_inc_aggr_plan_shape(plan: &LogicalPlan) -> Result<(), String> { +fn has_aggregate(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Aggregate(_) => true, + _ => plan.inputs().into_iter().any(has_aggregate), + } +} + +fn peel_subquery_aliases(mut plan: &LogicalPlan) -> &LogicalPlan { + while let LogicalPlan::SubqueryAlias(alias) = plan { + plan = alias.input.as_ref(); + } + plan +} + +fn extract_incremental_aggregate(plan: &LogicalPlan) -> Result, String> { // Supported final shape: optional output Projection directly over one // Aggregate. Post-aggregate filters (HAVING), ordering, limits, // distinct/window/union/extension nodes are intentionally not accepted. @@ -215,26 +192,34 @@ fn check_inc_aggr_plan_shape(plan: &LogicalPlan) -> Result<(), String> { }; match plan { - LogicalPlan::Aggregate(aggregate) => check_input_plan_shape(aggregate.input.as_ref()), - LogicalPlan::Filter(_) => Err( + LogicalPlan::Aggregate(aggregate) => { + check_input_plan_shape(aggregate.input.as_ref())?; + Ok(Some(aggregate)) + } + LogicalPlan::Filter(filter) if has_aggregate(filter.input.as_ref()) => Err( "unsupported post-aggregate filter (HAVING) in incremental aggregate rewrite" .to_string(), ), - _ => Err( + _ if has_aggregate(plan) => Err( "unsupported post-aggregate plan shape in incremental aggregate rewrite".to_string(), ), + _ => Ok(None), } } fn check_input_plan_shape(plan: &LogicalPlan) -> Result<(), String> { + let plan = peel_subquery_aliases(plan); match plan { // Supported aggregate input shape: optional WHERE filter over a table scan. + // SubqueryAlias is a transparent naming wrapper for `FROM table AS alias`. LogicalPlan::TableScan(_) => Ok(()), - LogicalPlan::Filter(filter) - if matches!(filter.input.as_ref(), LogicalPlan::TableScan(_)) => - { - Ok(()) - } + LogicalPlan::Filter(filter) => match peel_subquery_aliases(filter.input.as_ref()) { + LogicalPlan::TableScan(_) => Ok(()), + _ => Err( + "unsupported aggregate input plan shape in incremental aggregate rewrite" + .to_string(), + ), + }, _ => Err( "unsupported aggregate input plan shape in incremental aggregate rewrite".to_string(), ), @@ -405,13 +390,71 @@ fn find_uncovered_output_fields( .collect() } +fn find_unsupported_group_key_projection_outputs( + plan: &LogicalPlan, + aggregate: &Aggregate, + group_key_names: &[String], +) -> Vec { + let LogicalPlan::Projection(projection) = plan else { + return vec![]; + }; + + let group_key_names = group_key_names.iter().cloned().collect::>(); + let group_expr_names = aggregate + .group_expr + .iter() + .filter_map(|expr| expr.name_for_alias().ok()) + .collect::>(); + projection + .expr + .iter() + .filter_map(|expr| { + let output_name = expr.qualified_name().1; + if !group_key_names.contains(&output_name) { + return None; + } + + let source_name = match expr { + Expr::Alias(alias) => alias.expr.name_for_alias().ok(), + _ => expr.name_for_alias().ok(), + }; + if source_name.is_some_and(|name| group_expr_names.contains(&name)) { + None + } else { + Some(format!( + "unsupported group key output field is not a transparent group expression: {output_name}" + )) + } + }) + .collect() +} + pub fn analyze_incremental_aggregate_plan( plan: &LogicalPlan, ) -> Result, Error> { let group_key_names = find_group_key_names(plan)?; - let (aggregate_count, Some(aggr_exprs)) = find_aggregate_exprs(plan)? else { - return Ok(None); + let aggregate = match extract_incremental_aggregate(plan) { + Ok(Some(aggregate)) => aggregate, + Ok(None) => return Ok(None), + Err(reason) => { + let projection_info = collect_output_projection_info(plan); + let mut unsupported_exprs = projection_info + .duplicate_output_names() + .into_iter() + .map(|name| format!("duplicate output field name: {name}")) + .collect::>(); + unsupported_exprs.push(reason); + unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned()); + return Ok(Some(IncrementalAggregateAnalysis { + group_key_names, + merge_columns: vec![], + literal_columns: vec![], + output_field_names: projection_info.output_field_names, + unsupported_exprs, + })); + } }; + let aggr_exprs = aggregate.aggr_expr.clone(); let projection_info = collect_output_projection_info(plan); let output_field_name_set = projection_info.output_field_name_set(); @@ -421,14 +464,6 @@ pub fn analyze_incremental_aggregate_plan( .into_iter() .map(|name| format!("duplicate output field name: {name}")) .collect::>(); - if aggregate_count != 1 { - unsupported_exprs.push(format!( - "unsupported aggregate plan contains {aggregate_count} Aggregate nodes" - )); - } - if let Err(reason) = check_inc_aggr_plan_shape(plan) { - unsupported_exprs.push(reason); - } if has_grouping_set(plan) { unsupported_exprs.push( "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(), @@ -438,6 +473,11 @@ pub fn analyze_incremental_aggregate_plan( unsupported_exprs .push("unsupported global aggregate in incremental aggregate rewrite".to_string()); } + unsupported_exprs.extend(find_unsupported_group_key_projection_outputs( + plan, + aggregate, + &group_key_names, + )); unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned()); for aggr_expr in aggr_exprs { let merge_op = match merge_op_for_aggregate_expr(&aggr_expr) { @@ -1299,1267 +1339,3 @@ impl TreeNodeRewriter for AddFilterRewriter { } } } - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use common_recordbatch::RecordBatch; - use datafusion_common::tree_node::TreeNode as _; - use datafusion_expr::GroupingSet; - use datatypes::prelude::{ConcreteDataType, Scalar, VectorRef}; - use datatypes::schema::{ColumnSchema, Schema}; - use pretty_assertions::assert_eq; - use query::query_engine::DefaultSerializer; - use session::context::QueryContext; - use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; - use table::test_util::MemTable; - - use super::*; - use crate::test_utils::create_test_query_engine; - - fn u32_table(table_name: &str, columns: Vec<&str>, rows: usize) -> TableRef { - let column_schemas = columns - .iter() - .map(|name| ColumnSchema::new(*name, ConcreteDataType::uint32_datatype(), false)) - .collect::>(); - let vectors = columns - .iter() - .map(|_| Arc::new(::VectorType::from_vec(vec![1; rows])) as VectorRef) - .collect::>(); - let schema = Arc::new(Schema::new(column_schemas)); - let recordbatch = RecordBatch::new(schema, vectors).unwrap(); - MemTable::table(table_name, recordbatch) - } - - fn single_row_u32_table(table_name: &str, columns: Vec<&str>) -> TableRef { - u32_table(table_name, columns, 1) - } - - fn empty_u32_table(table_name: &str, columns: Vec<&str>) -> TableRef { - u32_table(table_name, columns, 0) - } - - fn assert_same_logical_plan(actual: &LogicalPlan, expected: &LogicalPlan) { - assert_eq!( - format!("{}", expected.display_indent()), - format!("{}", actual.display_indent()) - ); - } - - fn test_sink_scan(sink_table: TableRef, sink_table_name: &TableName) -> LogicalPlan { - let table_provider = Arc::new(DfTableProviderAdapter::new(sink_table)); - let table_source = Arc::new(DefaultTableSource::new(table_provider)); - LogicalPlan::TableScan( - TableScan::try_new( - TableReference::Full { - catalog: sink_table_name[0].clone().into(), - schema: sink_table_name[1].clone().into(), - table: sink_table_name[2].clone().into(), - }, - table_source, - None, - vec![], - None, - ) - .unwrap(), - ) - } - - fn expected_left_join_rewrite( - delta_plan: &LogicalPlan, - sink_table: TableRef, - sink_table_name: &TableName, - delta_selected_exprs: Vec, - sink_selected_exprs: Vec, - join_keys: (Vec, Vec), - projection_exprs: Vec, - ) -> LogicalPlan { - let delta_alias = "__flow_delta"; - let sink_alias = "__flow_sink"; - let delta_selected = LogicalPlanBuilder::from(delta_plan.clone()) - .project(delta_selected_exprs) - .unwrap() - .alias(delta_alias) - .unwrap() - .build() - .unwrap(); - let sink_selected = LogicalPlanBuilder::from(test_sink_scan(sink_table, sink_table_name)) - .project(sink_selected_exprs) - .unwrap() - .alias(sink_alias) - .unwrap() - .build() - .unwrap(); - let joined = LogicalPlanBuilder::from(delta_selected) - .join_detailed( - sink_selected, - JoinType::Left, - join_keys, - None, - NullEquality::NullEqualsNull, - ) - .unwrap() - .build() - .unwrap(); - LogicalPlanBuilder::from(joined) - .project(projection_exprs) - .unwrap() - .build() - .unwrap() - } - - fn max_merge_expr(field_name: &str) -> Expr { - let left = qualified_col("__flow_delta", field_name); - let right = qualified_col("__flow_sink", field_name); - when(is_null(right.clone()), left.clone()) - .when(left.clone().gt_eq(right.clone()), left) - .otherwise(right) - .unwrap() - .alias(field_name) - } - - fn sum_merge_expr(field_name: &str) -> Expr { - let left = qualified_col("__flow_delta", field_name); - let right = qualified_col("__flow_sink", field_name); - when(is_null(left.clone()), right.clone()) - .when(is_null(right.clone()), left.clone()) - .otherwise(binary_expr(left, Operator::Plus, right)) - .unwrap() - .alias(field_name) - } - - async fn analyze_test_sql(sql: &str) -> IncrementalAggregateAnalysis { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - analyze_incremental_aggregate_plan(&plan).unwrap().unwrap() - } - - async fn analyze_grouping_set_plan( - make_grouping_set: impl FnOnce(Expr) -> GroupingSet, - ) -> IncrementalAggregateAnalysis { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let plan = sql_to_df_plan( - ctx, - query_engine, - "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - false, - ) - .await - .unwrap(); - - let LogicalPlan::Projection(projection) = plan else { - panic!("expected projection over aggregate") - }; - let LogicalPlan::Aggregate(aggregate) = projection.input.as_ref() else { - panic!("expected aggregate below projection") - }; - let group_expr = aggregate.group_expr[0].clone(); - let grouping_set_aggregate = datafusion_expr::logical_plan::Aggregate::try_new( - aggregate.input.clone(), - vec![Expr::GroupingSet(make_grouping_set(group_expr))], - aggregate.aggr_expr.clone(), - ) - .unwrap(); - let plan = LogicalPlan::Aggregate(grouping_set_aggregate); - analyze_incremental_aggregate_plan(&plan).unwrap().unwrap() - } - - fn assert_unsupported(analysis: &IncrementalAggregateAnalysis, reason: &str) { - assert!( - analysis - .unsupported_exprs - .iter() - .any(|expr| expr.contains(reason)), - "expected unsupported reason containing {reason:?}, got {:?}", - analysis.unsupported_exprs - ); - assert!( - analysis.merge_columns.is_empty(), - "unsupported analysis should disable merge columns" - ); - } - - /// test if uppercase are handled correctly(with quote) - #[tokio::test] - async fn test_sql_plan_convert() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let old = r#"SELECT "NUMBER" FROM "UPPERCASE_NUMBERS_WITH_TS""#; - let new = sql_to_df_plan(ctx.clone(), query_engine.clone(), old, false) - .await - .unwrap(); - let new_sql = df_plan_to_sql(&new).unwrap(); - - assert_eq!( - r#"SELECT `UPPERCASE_NUMBERS_WITH_TS`.`NUMBER` FROM `UPPERCASE_NUMBERS_WITH_TS`"#, - new_sql - ); - } - - #[tokio::test] - async fn test_add_filter() { - let testcases = vec![ - ( - "SELECT number FROM numbers_with_ts GROUP BY number", - "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE (number > 4) GROUP BY numbers_with_ts.number", - ), - ( - "SELECT number FROM numbers_with_ts WHERE number < 2 OR number >10", - "SELECT numbers_with_ts.number FROM numbers_with_ts WHERE ((numbers_with_ts.number < 2) OR (numbers_with_ts.number > 10)) AND (number > 4)", - ), - ( - "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window", - "SELECT date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts)", - ), - // subquery - ( - "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);", - "SELECT numbers_with_ts.number, time_window FROM (SELECT numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window FROM numbers_with_ts WHERE (number > 4) GROUP BY date_bin('5 minutes', numbers_with_ts.ts), numbers_with_ts.number)", - ), - // complex subquery without alias - ( - "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;", - "SELECT sum(numbers_with_ts.number), numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts) AS time_window, bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) GROUP BY numbers_with_ts.number, date_bin('5 minutes', numbers_with_ts.ts), bucket_name", - ), - // complex subquery alias - ( - "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte WHERE number > 1 GROUP BY number, time_window, bucket_name;", - "SELECT sum(cte.number), cte.number, date_bin('5 minutes', cte.ts) AS time_window, cte.bucket_name FROM (SELECT numbers_with_ts.number, numbers_with_ts.ts, CASE WHEN (numbers_with_ts.number < 5) THEN 'bucket_0_5' WHEN (numbers_with_ts.number >= 5) THEN 'bucket_5_inf' END AS bucket_name FROM numbers_with_ts WHERE (number > 4)) AS cte WHERE (cte.number > 1) GROUP BY cte.number, date_bin('5 minutes', cte.ts), cte.bucket_name", - ), - ]; - use datafusion_expr::{col, lit}; - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - - for (before, after) in testcases { - let sql = before; - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - - let mut add_filter = AddFilterRewriter::new(col("number").gt(lit(4u32))); - let plan = plan.rewrite(&mut add_filter).unwrap().data; - let new_sql = df_plan_to_sql(&plan).unwrap(); - assert_eq!(after, new_sql); - } - } - - #[tokio::test] - async fn test_add_auto_column_rewriter() { - let testcases = vec![ - // add update_at - ( - "SELECT number FROM numbers_with_ts", - Ok("SELECT numbers_with_ts.number, now() AS ts FROM numbers_with_ts"), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new( - "ts", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ], - ), - // add ts placeholder - ( - "SELECT number FROM numbers_with_ts", - Ok( - "SELECT numbers_with_ts.number, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts", - ), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new( - AUTO_CREATED_PLACEHOLDER_TS_COL, - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ], - ), - // no modify - ( - "SELECT number, ts FROM numbers_with_ts", - Ok("SELECT numbers_with_ts.number, numbers_with_ts.ts FROM numbers_with_ts"), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new( - "ts", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ], - ), - // add update_at and ts placeholder - ( - "SELECT number FROM numbers_with_ts", - Ok( - "SELECT numbers_with_ts.number, now() AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts", - ), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new( - "update_at", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ), - ColumnSchema::new( - AUTO_CREATED_PLACEHOLDER_TS_COL, - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ], - ), - // add ts placeholder - ( - "SELECT number, ts FROM numbers_with_ts", - Ok( - "SELECT numbers_with_ts.number, numbers_with_ts.ts AS update_at, CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS __ts_placeholder FROM numbers_with_ts", - ), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new( - "update_at", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ), - ColumnSchema::new( - AUTO_CREATED_PLACEHOLDER_TS_COL, - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ], - ), - // add update_at after time index column - ( - "SELECT number, ts FROM numbers_with_ts", - Ok( - "SELECT numbers_with_ts.number, numbers_with_ts.ts, now() AS update_atat FROM numbers_with_ts", - ), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new( - "ts", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ColumnSchema::new( - // name is irrelevant for update_at column - "update_atat", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ), - ], - ), - // error datatype mismatch - ( - "SELECT number, ts FROM numbers_with_ts", - Err( - "Expect the last column in table to be timestamp column, found column atat with type Int8", - ), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new( - "ts", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ColumnSchema::new( - // name is irrelevant for update_at column - "atat", - ConcreteDataType::int8_datatype(), - false, - ), - ], - ), - // error datatype mismatch on second last column - ( - "SELECT number FROM numbers_with_ts", - Err( - "Expect the second last column in the table to be timestamp column, found column ts with type Int8", - ), - vec![ - ColumnSchema::new("number", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new("ts", ConcreteDataType::int8_datatype(), false), - ColumnSchema::new( - // name is irrelevant for update_at column - "atat", - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ], - ), - ]; - - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - for (before, after, column_schemas) in testcases { - let schema = Arc::new(Schema::new(column_schemas)); - let mut add_auto_column_rewriter = - ColumnMatcherRewriter::new(schema, Vec::new(), false); - - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), before, false) - .await - .unwrap(); - let new_sql = (|| { - let plan = plan - .rewrite(&mut add_auto_column_rewriter) - .map_err(|e| e.to_string())? - .data; - df_plan_to_sql(&plan).map_err(|e| e.to_string()) - })(); - match (after, new_sql.clone()) { - (Ok(after), Ok(new_sql)) => assert_eq!(after, new_sql), - (Err(expected), Err(real_err_msg)) => assert!( - real_err_msg.contains(expected), - "expected: {expected}, real: {real_err_msg}" - ), - _ => panic!("expected: {:?}, real: {:?}", after, new_sql), - } - } - } - - #[tokio::test] - async fn test_find_group_by_exprs() { - let testcases = vec![ - ( - "SELECT arrow_cast(date_bin(INTERVAL '1 MINS', numbers_with_ts.ts), 'Timestamp(Second, None)') AS ts FROM numbers_with_ts GROUP BY ts;", - vec!["ts"], - ), - ( - "SELECT number FROM numbers_with_ts GROUP BY number", - vec!["number"], - ), - ( - "SELECT date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window", - vec!["time_window"], - ), - // subquery - ( - "SELECT number, time_window FROM (SELECT number, date_bin('5 minutes', ts) as time_window FROM numbers_with_ts GROUP BY time_window, number);", - vec!["time_window", "number"], - ), - // complex subquery without alias - ( - "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) GROUP BY number, time_window, bucket_name;", - vec!["number", "time_window", "bucket_name"], - ), - // complex subquery alias - ( - "SELECT sum(number), number, date_bin('5 minutes', ts) as time_window, bucket_name FROM (SELECT number, ts, case when number < 5 THEN 'bucket_0_5' when number >= 5 THEN 'bucket_5_inf' END as bucket_name FROM numbers_with_ts) as cte GROUP BY number, time_window, bucket_name;", - vec!["number", "time_window", "bucket_name"], - ), - ]; - - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - for (sql, expected) in testcases { - // need to be unoptimize for better readiability - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - let mut group_by_exprs = FindGroupByFinalName::default(); - plan.visit(&mut group_by_exprs).unwrap(); - let expected: HashSet = expected.into_iter().map(|s| s.to_string()).collect(); - assert_eq!( - expected, - group_by_exprs.get_group_expr_names().unwrap_or_default() - ); - } - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let testcases: Vec<(&str, IncrementalAggregateMergeOp, &str)> = vec![ - ( - "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::Sum, - "number", - ), - ( - "SELECT count(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::Sum, - "number", - ), - ( - "SELECT min(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::Min, - "number", - ), - ( - "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::Max, - "number", - ), - ( - "SELECT bit_and(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::BitAnd, - "number", - ), - ( - "SELECT bit_or(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::BitOr, - "number", - ), - ( - "SELECT bit_xor(number) AS number, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::BitXor, - "number", - ), - ( - "SELECT bool_and(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::BoolAnd, - "cond", - ), - ( - "SELECT bool_or(number > 5) AS cond, ts FROM numbers_with_ts GROUP BY ts", - IncrementalAggregateMergeOp::BoolOr, - "cond", - ), - ]; - - for (sql, expected_op, expected_field_name) in testcases { - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert!(analysis.group_key_names.contains(&"ts".to_string())); - assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!( - analysis.merge_columns[0].output_field_name, - expected_field_name - ); - assert_eq!(analysis.merge_columns[0].merge_op, expected_op); - } - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_keeps_aliases_for_multiple_aggregates() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS max_number, min(number) AS min_number, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert!(analysis.group_key_names.contains(&"ts".to_string())); - assert_eq!(analysis.merge_columns.len(), 2); - assert!(analysis.merge_columns.iter().any(|merge_col| { - merge_col.output_field_name == "max_number" - && merge_col.merge_op == IncrementalAggregateMergeOp::Max - })); - assert!(analysis.merge_columns.iter().any(|merge_col| { - merge_col.output_field_name == "min_number" - && merge_col.merge_op == IncrementalAggregateMergeOp::Min - })); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_allows_where_before_aggregate() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = - "SELECT sum(number) AS number, ts FROM numbers_with_ts WHERE number > 10 GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert!(analysis.group_key_names.contains(&"ts".to_string())); - assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!(analysis.merge_columns[0].output_field_name, "number"); - assert_eq!( - analysis.merge_columns[0].merge_op, - IncrementalAggregateMergeOp::Sum - ); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_having_filter() { - let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts HAVING sum(number) > 10"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "post-aggregate filter"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_allows_aggregate_filter() { - let sql = "SELECT sum(number) FILTER (WHERE number > 10) AS number, ts FROM numbers_with_ts GROUP BY ts"; - let analysis = analyze_test_sql(sql).await; - - assert!(analysis.unsupported_exprs.is_empty()); - assert!(analysis.group_key_names.contains(&"ts".to_string())); - assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!(analysis.merge_columns[0].output_field_name, "number"); - assert_eq!( - analysis.merge_columns[0].merge_op, - IncrementalAggregateMergeOp::Sum - ); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_aggregate_order_by() { - let sql = "SELECT sum(number ORDER BY ts) AS number, ts FROM numbers_with_ts GROUP BY ts"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "aggregate ORDER BY"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_sort_above_aggregate() { - let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts ORDER BY number DESC"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "post-aggregate plan shape"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_limit_above_aggregate() { - let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts LIMIT 1"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "post-aggregate plan shape"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_distinct_above_aggregate() { - let sql = "SELECT DISTINCT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "post-aggregate plan shape"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_nested_aggregates() { - let sql = "SELECT sum(cnt) AS total FROM (SELECT count(*) AS cnt, ts FROM numbers_with_ts GROUP BY ts) s"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "Aggregate nodes"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_union_aggregate_branches() { - let sql = "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts UNION ALL SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "Aggregate nodes"); - assert_unsupported(&analysis, "post-aggregate plan shape"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_window_above_aggregate() { - let sql = "SELECT sum(number) AS number, ts, row_number() OVER (ORDER BY sum(number)) AS rn FROM numbers_with_ts GROUP BY ts"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "post-aggregate plan shape"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_join_below_aggregate() { - let sql = "SELECT sum(lhs.number) AS number, lhs.ts FROM numbers_with_ts AS lhs JOIN numbers_with_ts AS rhs ON lhs.ts = rhs.ts GROUP BY lhs.ts"; - let analysis = analyze_test_sql(sql).await; - assert_unsupported(&analysis, "aggregate input plan shape"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_grouping_sets() { - let analysis = - analyze_grouping_set_plan(|expr| GroupingSet::GroupingSets(vec![vec![expr]])).await; - assert_unsupported(&analysis, "GROUPING SETS/CUBE/ROLLUP"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_cube() { - let analysis = analyze_grouping_set_plan(|expr| GroupingSet::Cube(vec![expr])).await; - assert_unsupported(&analysis, "GROUPING SETS/CUBE/ROLLUP"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_rollup() { - let analysis = analyze_grouping_set_plan(|expr| GroupingSet::Rollup(vec![expr])).await; - assert_unsupported(&analysis, "GROUPING SETS/CUBE/ROLLUP"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_preserves_raw_aggregate_name() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number), ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!( - analysis.merge_columns[0].output_field_name, - "max(numbers_with_ts.number)" - ); - assert_eq!( - analysis.merge_columns[0].merge_op, - IncrementalAggregateMergeOp::Max - ); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_wrapper_aliased_as_raw_name() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = r#"SELECT COALESCE(max(number), 0) AS "max(numbers_with_ts.number)", ts FROM numbers_with_ts GROUP BY ts"#; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - !analysis.unsupported_exprs.is_empty(), - "wrapper aliased to a raw aggregate field name must not bypass analysis" - ); - assert!(analysis.merge_columns.is_empty()); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_supports_count_star() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT count(*) AS wildcard, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!(analysis.merge_columns[0].output_field_name, "wildcard"); - assert_eq!( - analysis.merge_columns[0].merge_op, - IncrementalAggregateMergeOp::Sum - ); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_supports_aggregate_input_exprs() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let testcases = [ - "SELECT sum(abs(number)) AS sum_abs, ts FROM numbers_with_ts GROUP BY ts", - "SELECT sum(CASE WHEN number > 5 THEN 1 ELSE 0 END) AS above_five, ts FROM numbers_with_ts GROUP BY ts", - ]; - - for sql in testcases { - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - analysis.unsupported_exprs.is_empty(), - "aggregate input expressions should be mergeable for SQL: {sql}" - ); - assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!( - analysis.merge_columns[0].merge_op, - IncrementalAggregateMergeOp::Sum - ); - } - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_output_expr_wrappers() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let testcases = [ - "SELECT abs(sum(number)) AS abs_sum, ts FROM numbers_with_ts GROUP BY ts", - "SELECT max(number) - min(number) AS maxmin, ts FROM numbers_with_ts GROUP BY ts", - "SELECT count(number) + 123 AS total_count, ts FROM numbers_with_ts GROUP BY ts", - "SELECT sum(CASE WHEN number > 5 THEN 1 ELSE 0 END) / count(number) AS ratio, ts FROM numbers_with_ts GROUP BY ts", - ]; - - for sql in testcases { - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - !analysis.unsupported_exprs.is_empty(), - "aggregate output wrappers should be rejected for SQL: {sql}" - ); - } - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_allows_literal_outputs() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS number, ts, 42 AS lit FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!(analysis.literal_columns, vec!["lit".to_string()]); - assert_eq!( - analysis.output_field_names, - vec!["number".to_string(), "ts".to_string(), "lit".to_string()] - ); - - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "numbers_with_ts".to_string(), - ]; - let (sink_table, _) = get_table_info_df_schema( - query_engine.engine_state().catalog_manager().clone(), - sink_table_name.clone(), - ) - .await - .unwrap(); - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table.clone(), - &sink_table_name, - ) - .await - .unwrap(); - - let rewritten_fields = rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(); - assert_eq!(rewritten_fields, analysis.output_field_names); - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![ - unqualified_col("ts"), - unqualified_col("number"), - unqualified_col("lit"), - ], - vec![unqualified_col("ts"), unqualified_col("number")], - ( - vec![qualified_column("__flow_delta", "ts")], - vec![qualified_column("__flow_sink", "ts")], - ), - vec![ - max_merge_expr("number"), - qualified_col("__flow_delta", "ts").alias("ts"), - qualified_col("__flow_delta", "lit").alias("lit"), - ], - ); - assert_same_logical_plan(&rewritten, &expected); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_allows_unaliased_literal_output() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT 42, max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!(analysis.literal_columns.len(), 1); - assert_eq!(analysis.output_field_names[0], analysis.literal_columns[0]); - assert_eq!(analysis.output_field_names[1], "number"); - assert_eq!(analysis.output_field_names[2], "ts"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_allows_string_literal_output() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = - "SELECT max(number) AS number, ts, 'hello' AS label FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!(analysis.literal_columns, vec!["label".to_string()]); - assert_eq!( - analysis.output_field_names, - vec!["number".to_string(), "ts".to_string(), "label".to_string()] - ); - } - - #[tokio::test] - async fn test_rewrite_incremental_aggregate_preserves_non_identifier_aliases() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS \"max value\", number, 42 AS \"literal value\" FROM numbers_with_ts GROUP BY number"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert_eq!( - analysis.output_field_names, - vec!["max value", "number", "literal value"] - ); - - let sink_table = - single_row_u32_table("non_identifier_alias_sink", vec!["number", "max value"]); - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table, - &[ - "greptime".to_string(), - "public".to_string(), - "non_identifier_alias_sink".to_string(), - ], - ) - .await - .unwrap(); - - assert_eq!( - rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(), - vec![ - "max value".to_string(), - "number".to_string(), - "literal value".to_string() - ] - ); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_uncovered_outputs() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT sum(number) AS total, number + 1 AS bucket FROM numbers_with_ts GROUP BY number"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - analysis - .unsupported_exprs - .iter() - .any(|expr| expr.contains("unsupported output field: bucket")), - "non-literal extra output should be rejected: {:?}", - analysis.unsupported_exprs - ); - } - - #[tokio::test] - async fn test_datafusion_rejects_duplicate_output_names() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS x, min(number) AS x, ts FROM numbers_with_ts GROUP BY ts"; - let err = sql_to_df_plan(ctx, query_engine, sql, false) - .await - .unwrap_err(); - let err = format!("{err:?}"); - assert!( - err.contains("Projections require unique expression names"), - "DataFusion should reject duplicate output aliases before incremental analysis: {err}" - ); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_same_aggregate_multiple_aliases() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT sum(number) AS a, sum(number) AS b, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - analysis - .unsupported_exprs - .iter() - .any(|expr| expr.contains("same aggregate output") - && expr.contains("a") - && expr.contains("b")), - "same aggregate with multiple aliases should be unsupported until explicit reproduction is implemented: {:?}", - analysis.unsupported_exprs - ); - assert!(analysis.merge_columns.is_empty()); - } - - #[test] - fn test_qualified_col_preserves_non_identifier_field_name() { - let expr = qualified_col("__flow_delta", "max(numbers_with_ts.number)"); - let Expr::Column(column) = expr else { - panic!("expected column expression"); - }; - assert_eq!(column.name, "max(numbers_with_ts.number)"); - assert_eq!(column.relation.unwrap().to_string(), "__flow_delta"); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_multiple_group_keys() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT sum(number) AS total, ts, number AS bucket FROM numbers_with_ts GROUP BY ts, number"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - assert!(analysis.group_key_names.contains(&"ts".to_string())); - assert!(analysis.group_key_names.contains(&"bucket".to_string())); - assert_eq!(analysis.group_key_names.len(), 2); - assert_eq!(analysis.merge_columns.len(), 1); - assert_eq!(analysis.merge_columns[0].output_field_name, "total"); - assert_eq!( - analysis.merge_columns[0].merge_op, - IncrementalAggregateMergeOp::Sum - ); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_avg() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT avg(number) AS avg_num, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(!analysis.unsupported_exprs.is_empty()); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_distinct() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT count(distinct number) AS number, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(!analysis.unsupported_exprs.is_empty()); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_coalesce_wrapped_aggregate() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - // COALESCE wraps the aggregate output — the wrapper is not merge-transparent, - // so the analyzer should mark the aggregate as unsupported rather than - // attempting an unsafe incremental rewrite. - let sql = - "SELECT COALESCE(max(number), 0) AS coalesced_max, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - // Non-transparent wrapper → alias unresolvable → unsupported - assert!( - !analysis.unsupported_exprs.is_empty(), - "COALESCE-wrapped aggregate should be unsupported" - ); - assert!( - analysis.merge_columns.is_empty(), - "COALESCE-wrapped aggregate should have no merge columns" - ); - } - - #[tokio::test] - async fn test_rewrite_incremental_aggregate_with_left_join() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS number, ts FROM numbers_with_ts GROUP BY ts"; - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "numbers_with_ts".to_string(), - ]; - let (sink_table, _) = get_table_info_df_schema( - query_engine.engine_state().catalog_manager().clone(), - sink_table_name.clone(), - ) - .await - .unwrap(); - - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table.clone(), - &sink_table_name, - ) - .await - .unwrap(); - - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![unqualified_col("ts"), unqualified_col("number")], - vec![unqualified_col("ts"), unqualified_col("number")], - ( - vec![qualified_column("__flow_delta", "ts")], - vec![qualified_column("__flow_sink", "ts")], - ), - vec![ - max_merge_expr("number"), - qualified_col("__flow_delta", "ts").alias("ts"), - ], - ); - assert_same_logical_plan(&rewritten, &expected); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_global_aggregate() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let testcases = [ - "SELECT max(number) AS number FROM numbers_with_ts", - "SELECT max(number) AS number, 42 AS lit FROM numbers_with_ts", - "SELECT count(*) AS cnt, sum(number) AS total FROM numbers_with_ts", - ]; - - for sql in testcases { - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert_unsupported(&analysis, "global aggregate"); - } - } - - #[tokio::test] - async fn test_rewrite_incremental_aggregate_rejects_empty_group_keys() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number) AS number FROM numbers_with_ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = IncrementalAggregateAnalysis { - group_key_names: vec![], - merge_columns: vec![IncrementalAggregateMergeColumn::new( - "number".to_string(), - IncrementalAggregateMergeOp::Max, - )], - literal_columns: vec![], - output_field_names: vec!["number".to_string()], - unsupported_exprs: vec![], - }; - - let sink_table = single_row_u32_table("global_guard_sink", vec!["number"]); - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "global_guard_sink".to_string(), - ]; - let err = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table, - &sink_table_name, - ) - .await - .unwrap_err(); - let err = format!("{err:?}"); - assert!( - err.contains("global aggregate query is not supported"), - "rewrite should defensively reject empty group keys: {err}" - ); - } - - #[tokio::test] - async fn test_rewrite_incremental_aggregate_preserves_raw_aggregate_field_name() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT max(number), number FROM numbers_with_ts GROUP BY number"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!(analysis.unsupported_exprs.is_empty()); - - let raw_field_name = "max(numbers_with_ts.number)"; - let sink_table = single_row_u32_table("raw_aggregate_sink", vec!["number", raw_field_name]); - let sink_table_name = [ - "greptime".to_string(), - "public".to_string(), - "raw_aggregate_sink".to_string(), - ]; - let rewritten = rewrite_incremental_aggregate_with_sink_merge( - &plan, - &analysis, - sink_table.clone(), - &sink_table_name, - ) - .await - .unwrap(); - - let rewritten_fields = rewritten - .schema() - .fields() - .iter() - .map(|field| field.name().clone()) - .collect::>(); - assert!(rewritten_fields.contains(&raw_field_name.to_string())); - let expected = expected_left_join_rewrite( - &plan, - sink_table, - &sink_table_name, - vec![unqualified_col("number"), unqualified_col(raw_field_name)], - vec![unqualified_col("number"), unqualified_col(raw_field_name)], - ( - vec![qualified_column("__flow_delta", "number")], - vec![qualified_column("__flow_sink", "number")], - ), - vec![ - max_merge_expr(raw_field_name), - qualified_col("__flow_delta", "number").alias("number"), - ], - ); - assert_same_logical_plan(&rewritten, &expected); - } - - #[tokio::test] - async fn test_null_cast() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let sql = "SELECT NULL::DOUBLE FROM numbers_with_ts"; - let plan = sql_to_df_plan(ctx, query_engine.clone(), sql, false) - .await - .unwrap(); - - let _sub_plan = DFLogicalSubstraitConvertor {} - .encode(&plan, DefaultSerializer) - .unwrap(); - } - - #[tokio::test] - async fn test_analyze_incremental_aggregate_plan_rejects_cast_wrapped_alias() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - let testcases = [ - "SELECT CAST(sum(number) AS BIGINT) AS total, ts FROM numbers_with_ts GROUP BY ts", - "SELECT TRY_CAST(sum(number) AS BIGINT) AS total, ts FROM numbers_with_ts GROUP BY ts", - ]; - - for sql in testcases { - let plan = sql_to_df_plan(ctx.clone(), query_engine.clone(), sql, false) - .await - .unwrap(); - let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); - assert!( - !analysis.unsupported_exprs.is_empty(), - "CAST/TryCast-wrapped aggregate output should be unsupported for SQL: {sql}" - ); - } - } - - #[tokio::test] - async fn test_aggregate_expr_finder_counts_multiple_aggregates() { - let query_engine = create_test_query_engine(); - let ctx = QueryContext::arc(); - // Subquery has an inner aggregate (count), outer query has another aggregate (sum). - let sql = "SELECT sum(cnt) AS total, ts \ - FROM (SELECT ts, count(number) AS cnt FROM numbers_with_ts GROUP BY ts) AS sub \ - GROUP BY ts"; - let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); - - let mut finder = AggregateExprFinder::default(); - plan.visit(&mut finder).unwrap(); - assert!( - finder.aggregate_count > 1, - "nested aggregate plans should be identifiable as unsupported" - ); - } -} diff --git a/src/flow/src/batching_mode/utils/test.rs b/src/flow/src/batching_mode/utils/test.rs index 7d815155e2..863580b4ae 100644 --- a/src/flow/src/batching_mode/utils/test.rs +++ b/src/flow/src/batching_mode/utils/test.rs @@ -623,6 +623,37 @@ async fn test_analyze_incremental_aggregate_plan_allows_alias_wrapped_scan() { } } +#[tokio::test] +async fn test_rewrite_incremental_aggregate_allows_alias_wrapped_scan() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = "SELECT max(n.number) AS number, n.ts FROM numbers_with_ts AS n GROUP BY n.ts"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert!(analysis.unsupported_exprs.is_empty()); + + let rewritten = rewrite_incremental_aggregate_with_sink_merge( + &plan, + &analysis, + single_row_u32_table("alias_wrapped_sink", vec!["ts", "number"]), + &[ + "greptime".to_string(), + "public".to_string(), + "alias_wrapped_sink".to_string(), + ], + ) + .await + .unwrap(); + + let rewritten_fields = rewritten + .schema() + .fields() + .iter() + .map(|field| field.name().clone()) + .collect::>(); + assert_eq!(rewritten_fields, analysis.output_field_names); +} + #[tokio::test] async fn test_analyze_incremental_aggregate_plan_rejects_having_filter() { let sql = @@ -982,6 +1013,18 @@ async fn test_analyze_incremental_aggregate_plan_rejects_uncovered_outputs() { ); } +#[tokio::test] +async fn test_analyze_incremental_aggregate_plan_rejects_computed_group_key_shadow() { + let query_engine = create_test_query_engine(); + let ctx = QueryContext::arc(); + let sql = + "SELECT number + 1 AS number, sum(number) AS total FROM numbers_with_ts GROUP BY number"; + let plan = sql_to_df_plan(ctx, query_engine, sql, false).await.unwrap(); + + let analysis = analyze_incremental_aggregate_plan(&plan).unwrap().unwrap(); + assert_unsupported(&analysis, "not a transparent group expression"); +} + #[tokio::test] async fn test_datafusion_rejects_duplicate_output_names() { let query_engine = create_test_query_engine();