diff --git a/src/flow/src/batching_mode/utils.rs b/src/flow/src/batching_mode/utils.rs
index 56c22f5a07..fdce5ccd09 100644
--- a/src/flow/src/batching_mode/utils.rs
+++ b/src/flow/src/batching_mode/utils.rs
@@ -194,6 +194,22 @@ fn find_aggregate_exprs(plan: &LogicalPlan) -> Result<(usize, Option>)
     ))
 }
 
+/// Returns `true` when any `Aggregate` node reachable from `plan` through
+/// `LogicalPlan::inputs` has an `Expr::GroupingSet` among its group
+/// expressions (i.e. GROUPING SETS, CUBE, or ROLLUP).
+fn has_grouping_set(plan: &LogicalPlan) -> bool {
+    let grouped_here = match plan {
+        LogicalPlan::Aggregate(aggregate) => aggregate
+            .group_expr
+            .iter()
+            .any(|expr| matches!(expr, Expr::GroupingSet(_))),
+        _ => false,
+    };
+    // Keep descending below an aggregate as well, so a grouping set nested
+    // under another aggregate is still reported.
+    grouped_here || plan.inputs().into_iter().any(has_grouping_set)
+}
+
 fn check_inc_aggr_plan_shape(plan: &LogicalPlan) -> Result<(), String> {
     // Supported final shape: optional output Projection directly over one
     // Aggregate. Post-aggregate filters (HAVING), ordering, limits,
@@ -418,6 +434,11 @@ pub fn analyze_incremental_aggregate_plan(
     if let Err(reason) = check_inc_aggr_plan_shape(plan) {
         unsupported_exprs.push(reason);
     }
+    if has_grouping_set(plan) {
+        unsupported_exprs.push(
+            "unsupported GROUPING SETS/CUBE/ROLLUP in incremental aggregate rewrite".to_string(),
+        );
+    }
     unsupported_exprs.extend(projection_info.duplicate_aggregate_aliases.iter().cloned());
     if group_key_names.is_empty()
         && projection_info
@@ -1313,6 +1334,7 @@ mod test {
 
     use common_recordbatch::RecordBatch;
     use datafusion_common::tree_node::TreeNode as _;
+    use datafusion_expr::GroupingSet;
     use datatypes::prelude::{ConcreteDataType, Scalar, VectorRef};
     use datatypes::schema::{ColumnSchema, Schema};
     use pretty_assertions::assert_eq;
@@ -1442,6 +1464,39 @@ mod test {
         analyze_incremental_aggregate_plan(&plan).unwrap().unwrap()
     }
 
+    /// Plans a plain `GROUP BY ts` query, rebuilds its aggregate with the group key
+    /// wrapped by `make_grouping_set`, and analyzes that bare aggregate plan.
+    async fn analyze_grouping_set_plan(
+        make_grouping_set: impl FnOnce(Expr) -> GroupingSet,
+    ) -> IncrementalAggregateAnalysis {
+        let query_engine = create_test_query_engine();
+        let ctx = QueryContext::arc();
+        let plan = sql_to_df_plan(
+            ctx,
+            query_engine,
+            "SELECT sum(number) AS number, ts FROM numbers_with_ts GROUP BY ts",
+            false,
+        )
+        .await
+        .unwrap();
+
+        let LogicalPlan::Projection(projection) = plan else {
+            panic!("expected projection over aggregate")
+        };
+        let LogicalPlan::Aggregate(aggregate) = projection.input.as_ref() else {
+            panic!("expected aggregate below projection")
+        };
+        let group_expr = aggregate.group_expr[0].clone();
+        let grouping_set_aggregate = datafusion_expr::logical_plan::Aggregate::try_new(
+            aggregate.input.clone(),
+            vec![Expr::GroupingSet(make_grouping_set(group_expr))],
+            aggregate.aggr_expr.clone(),
+        )
+        .unwrap();
+        let plan = LogicalPlan::Aggregate(grouping_set_aggregate);
+        analyze_incremental_aggregate_plan(&plan).unwrap().unwrap()
+    }
+
     fn assert_unsupported(analysis: &IncrementalAggregateAnalysis, reason: &str) {
         assert!(
             analysis
@@ -1940,6 +1995,25 @@ mod test {
         assert_unsupported(&analysis, "aggregate input plan shape");
     }
 
+    #[tokio::test]
+    async fn test_analyze_incremental_aggregate_plan_rejects_grouping_sets() {
+        let analysis =
+            analyze_grouping_set_plan(|expr| GroupingSet::GroupingSets(vec![vec![expr]])).await;
+        assert_unsupported(&analysis, "GROUPING SETS/CUBE/ROLLUP");
+    }
+
+    #[tokio::test]
+    async fn test_analyze_incremental_aggregate_plan_rejects_cube() {
+        let analysis = analyze_grouping_set_plan(|expr| GroupingSet::Cube(vec![expr])).await;
+        assert_unsupported(&analysis, "GROUPING SETS/CUBE/ROLLUP");
+    }
+
+    #[tokio::test]
+    async fn test_analyze_incremental_aggregate_plan_rejects_rollup() {
+        let analysis = analyze_grouping_set_plan(|expr| GroupingSet::Rollup(vec![expr])).await;
+        assert_unsupported(&analysis, "GROUPING SETS/CUBE/ROLLUP");
+    }
+
     #[tokio::test]
     async fn test_analyze_incremental_aggregate_plan_preserves_raw_aggregate_name() {
         let query_engine = create_test_query_engine();