chore: extra check for not rewrite aggr node

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-01-26 17:04:33 +08:00
parent 2f82e7525f
commit 810bb14b16
2 changed files with 65 additions and 1 deletions

View File

@@ -304,7 +304,7 @@ impl PlanRewriter {
/// Return true if should stop and expand. The input plan is the parent node of current node
fn should_expand(&mut self, plan: &LogicalPlan) -> DfResult<bool> {
debug!(
"Check should_expand at level: {} with Stack:\n{}, ",
"Check should_expand at level: {} with Stack:\n{}\nWith plan=\n{plan} ",
self.level,
self.stack
.iter()
@@ -789,6 +789,18 @@ impl TreeNodeRewriter for EnforceDistRequirementRewriter {
// still need to continue for next projection if applicable
return Ok(Transformed::yes(new_node));
} else if let LogicalPlan::Aggregate(_) = node {
// something is wrong, we shouldn't add column requirements for aggregate node
// because aggregate node will change the schema and may drop certain columns rightfully
// need to return a error with enough debug info for debugging
let applicable_column_requirements =
self.get_current_applicable_column_requirements(&node)?;
if !applicable_column_requirements.is_empty() {
return Err(datafusion_common::DataFusionError::Internal(format!(
"EnforceDistRequirementRewriter: aggregate node should not have applicable column requirements at level {} for node {}: {:?}",
self.cur_level, node, applicable_column_requirements
)));
}
}
Ok(Transformed::no(node))
}

View File

@@ -748,6 +748,58 @@ fn expand_part_col_aggr_part_col_aggr() {
assert_eq!(expected, result.to_string());
}
#[test]
fn expand_sort_part_col_aggr_part_col_aggr() {
// use logging for better debugging
init_default_ut_logging();
let test_table = TestTable::table_with_name(0, "t".to_string());
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(test_table),
)));
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
.unwrap()
.sort(vec![
col("pk1").sort(true, false),
col("pk2").sort(true, false),
])
.unwrap()
.aggregate(vec![col("pk1"), col("pk2")], vec![max(col("number"))])
.unwrap()
.aggregate(
vec![col("pk1"), col("pk2")],
vec![min(col("max(t.number)"))],
)
.unwrap()
.build()
.unwrap();
let expected_original = [
// See DataFusion #14860 for change details.
"Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[min(max(t.number))]]",
" Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[max(t.number)]]",
" Sort: t.pk1 ASC NULLS LAST, t.pk2 ASC NULLS LAST",
" TableScan: t",
]
.join("\n");
assert_eq!(expected_original, plan.to_string());
let config = ConfigOptions::default();
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
let expected = [
"Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[min(max(t.number))]]",
" Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[max(t.number)]]",
" Projection: t.pk1, t.pk2, t.pk3, t.ts, t.number",
" MergeSort: t.pk1 ASC NULLS LAST, t.pk2 ASC NULLS LAST",
" MergeScan [is_placeholder=false, remote_input=[",
"Sort: t.pk1 ASC NULLS LAST, t.pk2 ASC NULLS LAST",
" TableScan: t",
"]]",
]
.join("\n");
assert_eq!(expected, result.to_string());
}
#[test]
fn expand_step_aggr_proj() {
// use logging for better debugging