diff --git a/src/query/src/dist_plan/analyzer.rs b/src/query/src/dist_plan/analyzer.rs index 6eb8b5eac6..e9dd698354 100644 --- a/src/query/src/dist_plan/analyzer.rs +++ b/src/query/src/dist_plan/analyzer.rs @@ -304,7 +304,7 @@ impl PlanRewriter { /// Return true if should stop and expand. The input plan is the parent node of current node fn should_expand(&mut self, plan: &LogicalPlan) -> DfResult { debug!( - "Check should_expand at level: {} with Stack:\n{}, ", + "Check should_expand at level: {} with Stack:\n{}\nWith plan=\n{plan} ", self.level, self.stack .iter() @@ -789,6 +789,18 @@ impl TreeNodeRewriter for EnforceDistRequirementRewriter { // still need to continue for next projection if applicable return Ok(Transformed::yes(new_node)); + } else if let LogicalPlan::Aggregate(_) = node { + // something is wrong, we shouldn't add column requirements for aggregate node + // because aggregate node will change the schema and may drop certain columns rightfully + // need to return a error with enough debug info for debugging + let applicable_column_requirements = + self.get_current_applicable_column_requirements(&node)?; + if !applicable_column_requirements.is_empty() { + return Err(datafusion_common::DataFusionError::Internal(format!( + "EnforceDistRequirementRewriter: aggregate node should not have applicable column requirements at level {} for node {}: {:?}", + self.cur_level, node, applicable_column_requirements + ))); + } } Ok(Transformed::no(node)) } diff --git a/src/query/src/dist_plan/analyzer/test.rs b/src/query/src/dist_plan/analyzer/test.rs index 7d4578f3dd..b42accd91c 100644 --- a/src/query/src/dist_plan/analyzer/test.rs +++ b/src/query/src/dist_plan/analyzer/test.rs @@ -748,6 +748,58 @@ fn expand_part_col_aggr_part_col_aggr() { assert_eq!(expected, result.to_string()); } +#[test] +fn expand_sort_part_col_aggr_part_col_aggr() { + // use logging for better debugging + init_default_ut_logging(); + let test_table = TestTable::table_with_name(0, "t".to_string()); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(test_table), + ))); + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .sort(vec![ + col("pk1").sort(true, false), + col("pk2").sort(true, false), + ]) + .unwrap() + .aggregate(vec![col("pk1"), col("pk2")], vec![max(col("number"))]) + .unwrap() + .aggregate( + vec![col("pk1"), col("pk2")], + vec![min(col("max(t.number)"))], + ) + .unwrap() + .build() + .unwrap(); + + let expected_original = [ + // See DataFusion #14860 for change details. + "Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[min(max(t.number))]]", + " Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[max(t.number)]]", + " Sort: t.pk1 ASC NULLS LAST, t.pk2 ASC NULLS LAST", + " TableScan: t", + ] + .join("\n"); + assert_eq!(expected_original, plan.to_string()); + + let config = ConfigOptions::default(); + let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap(); + + let expected = [ + "Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[min(max(t.number))]]", + " Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[max(t.number)]]", + " Projection: t.pk1, t.pk2, t.pk3, t.ts, t.number", + " MergeSort: t.pk1 ASC NULLS LAST, t.pk2 ASC NULLS LAST", + " MergeScan [is_placeholder=false, remote_input=[", + "Sort: t.pk1 ASC NULLS LAST, t.pk2 ASC NULLS LAST", + " TableScan: t", + "]]", + ] + .join("\n"); + assert_eq!(expected, result.to_string()); +} + #[test] fn expand_step_aggr_proj() { // use logging for better debugging