fix: recover plan schema after dist analyzer (#5665)

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2025-03-06 16:29:55 -08:00
committed by GitHub
parent 0124a0d156
commit e463942a5b
4 changed files with 67 additions and 28 deletions

View File

@@ -19,6 +19,7 @@ use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::Column;
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
@@ -215,7 +216,7 @@ impl PlanRewriter {
}
for col in container {
self.column_requirements.insert(col.flat_name());
self.column_requirements.insert(col.quoted_flat_name());
}
}
@@ -270,6 +271,8 @@ impl PlanRewriter {
}
fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
// store schema before expand
let schema = on_node.schema().clone();
let mut rewriter = EnforceDistRequirementRewriter {
column_requirements: std::mem::take(&mut self.column_requirements),
};
@@ -285,6 +288,13 @@ impl PlanRewriter {
}
self.set_expanded();
// recover the schema
let node = LogicalPlanBuilder::from(node)
.project(schema.iter().map(|(qualifier, field)| {
Expr::Column(Column::new(qualifier.cloned(), field.name()))
}))?
.build()?;
Ok(node)
}
}
@@ -447,7 +457,8 @@ mod test {
let config = ConfigOptions::default();
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
let expected = "MergeScan [is_placeholder=false]";
let expected = "Projection: avg(t.number)\
\n MergeScan [is_placeholder=false]";
assert_eq!(expected, result.to_string());
}
@@ -472,7 +483,8 @@ mod test {
let expected = [
"Sort: t.number ASC NULLS LAST",
" Distinct:",
" MergeScan [is_placeholder=false]",
" Projection: t.number",
" MergeScan [is_placeholder=false]",
]
.join("\n");
assert_eq!(expected, result.to_string());
@@ -494,7 +506,8 @@ mod test {
let config = ConfigOptions::default();
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
let expected = "MergeScan [is_placeholder=false]";
let expected = "Projection: t.number\
\n MergeScan [is_placeholder=false]";
assert_eq!(expected, result.to_string());
}
@@ -531,11 +544,16 @@ mod test {
let config = ConfigOptions::default();
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
let expected = "Limit: skip=0, fetch=1\
\n LeftSemi Join: Filter: t.number = right.number\
\n MergeScan [is_placeholder=false]\
\n SubqueryAlias: right\
\n MergeScan [is_placeholder=false]";
let expected = [
"Limit: skip=0, fetch=1",
" LeftSemi Join: Filter: t.number = right.number",
" Projection: t.number",
" MergeScan [is_placeholder=false]",
" SubqueryAlias: right",
" Projection: t.number",
" MergeScan [is_placeholder=false]",
]
.join("\n");
assert_eq!(expected, result.to_string());
}
}