fix: handle hash distribution properly (#6943)

* fix: handle hash distribution properly

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix clippy

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* Update src/query/src/optimizer/pass_distribution.rs

Co-authored-by: dennis zhuang <killme2008@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: dennis zhuang <killme2008@gmail.com>
Authored by Ruihang Xia on 2025-09-09 23:35:10 -07:00, committed by Weny Xu
parent aa7e7942f8
commit a67803d0e9
3 changed files with 85 additions and 52 deletions

View File

@@ -420,17 +420,22 @@ impl MergeScanExec {
             return None;
         }
-        let mut hash_cols = HashSet::default();
+        let partition_cols = self
+            .partition_cols
+            .iter()
+            .map(|x| x.as_str())
+            .collect::<HashSet<_>>();
+        let mut overlaps = vec![];
         for expr in &hash_exprs {
-            if let Some(col_expr) = expr.as_any().downcast_ref::<Column>() {
-                hash_cols.insert(col_expr.name());
+            // TODO(ruihang): tracking aliases
+            if let Some(col_expr) = expr.as_any().downcast_ref::<Column>()
+                && partition_cols.contains(col_expr.name())
+            {
+                overlaps.push(expr.clone());
             }
         }
-        for col in &self.partition_cols {
-            if !hash_cols.contains(col.as_str()) {
-                // The partitioning columns are not the same
-                return None;
-            }
+        if overlaps.is_empty() {
+            return None;
         }
         Some(Self {
@@ -443,7 +448,7 @@ impl MergeScanExec {
             metric: self.metric.clone(),
             properties: PlanProperties::new(
                 self.properties.eq_properties.clone(),
-                Partitioning::Hash(hash_exprs, self.target_partition),
+                Partitioning::Hash(overlaps, self.target_partition),
                 self.properties.emission_type,
                 self.properties.boundedness,
             ),
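
Taken on its own, the change to try_with_new_distribution swaps an all-or-nothing check (every partition column had to appear among the hash expressions, after which the full hash_exprs list was advertised) for an intersection: only the hash expressions that name a partition column are kept, and the scan reports Partitioning::Hash over that overlap, returning None when the overlap is empty. Below is a minimal sketch of the same selection using plain strings in place of DataFusion's PhysicalExpr/Column types; the function name and inputs are illustrative, not the crate's API.

use std::collections::HashSet;

/// Keep only the hash expressions that refer to one of the scan's
/// partition columns; `None` means the requirement cannot be adopted.
fn overlap_with_partition_cols(
    hash_exprs: &[String],
    partition_cols: &[String],
) -> Option<Vec<String>> {
    let partition_cols: HashSet<&str> = partition_cols.iter().map(|s| s.as_str()).collect();
    let overlaps: Vec<String> = hash_exprs
        .iter()
        .filter(|expr| partition_cols.contains(expr.as_str()))
        .cloned()
        .collect();
    if overlaps.is_empty() {
        // No common column: advertising any hash partitioning would be wrong.
        None
    } else {
        Some(overlaps)
    }
}

fn main() {
    let hash_exprs = vec!["host".to_string(), "ts".to_string()];
    let partition_cols = vec!["host".to_string()];
    // Only `host` survives, so the scan would report Hash([host], n) instead of
    // claiming to satisfy Hash([host, ts], n) as the old check allowed.
    assert_eq!(
        overlap_with_partition_cols(&hash_exprs, &partition_cols),
        Some(vec!["host".to_string()])
    );
}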

View File

@@ -17,7 +17,6 @@ use std::sync::Arc;
 use datafusion::config::ConfigOptions;
 use datafusion::physical_optimizer::PhysicalOptimizerRule;
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::Result as DfResult;
 use datafusion_physical_expr::Distribution;
@@ -56,26 +55,52 @@ impl PassDistribution {
         plan: Arc<dyn ExecutionPlan>,
         _config: &ConfigOptions,
     ) -> DfResult<Arc<dyn ExecutionPlan>> {
-        let mut distribution_requirement = None;
-        let result = plan.transform_down(|plan| {
-            if let Some(distribution) = plan.required_input_distribution().first()
-                && !matches!(distribution, Distribution::UnspecifiedDistribution)
-                // incorrect workaround, doesn't fix the actual issue
-                && plan.name() != "HashJoinExec"
-            {
-                distribution_requirement = Some(distribution.clone());
-            }
-            if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
-                && let Some(distribution) = distribution_requirement.as_ref()
-                && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
-            {
-                Ok(Transformed::yes(Arc::new(new_plan) as _))
-            } else {
-                Ok(Transformed::no(plan))
-            }
-        })?;
-        Ok(result.data)
+        // Start from root with no requirement
+        Self::rewrite_with_distribution(plan, None)
+    }
+
+    /// Top-down rewrite that propagates distribution requirements to children.
+    fn rewrite_with_distribution(
+        plan: Arc<dyn ExecutionPlan>,
+        current_req: Option<Distribution>,
+    ) -> DfResult<Arc<dyn ExecutionPlan>> {
+        // If this is a MergeScanExec, try to apply the current requirement.
+        if let Some(merge_scan) = plan.as_any().downcast_ref::<MergeScanExec>()
+            && let Some(distribution) = current_req.as_ref()
+            && let Some(new_plan) = merge_scan.try_with_new_distribution(distribution.clone())
+        {
+            // Leaf node; no children to process
+            return Ok(Arc::new(new_plan) as _);
+        }
+
+        // Compute per-child requirements from the current node.
+        let children = plan.children();
+        if children.is_empty() {
+            return Ok(plan);
+        }
+        let required = plan.required_input_distribution();
+        let mut new_children = Vec::with_capacity(children.len());
+        for (idx, child) in children.into_iter().enumerate() {
+            let child_req = match required.get(idx) {
+                Some(Distribution::UnspecifiedDistribution) => None,
+                None => current_req.clone(),
+                Some(req) => Some(req.clone()),
+            };
+            let new_child = Self::rewrite_with_distribution(child.clone(), child_req)?;
+            new_children.push(new_child);
+        }
+        // Rebuild the node only if any child changed (pointer inequality)
+        let unchanged = plan
+            .children()
+            .into_iter()
+            .zip(new_children.iter())
+            .all(|(old, new)| Arc::ptr_eq(old, new));
+        if unchanged {
+            Ok(plan)
+        } else {
+            plan.with_new_children(new_children)
+        }
     }
 }
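
The optimizer side replaces the single mutable distribution_requirement captured during transform_down (with its HashJoinExec special case) by an explicit top-down recursion: a MergeScanExec that can satisfy the incoming requirement is rewritten in place; otherwise each child receives the requirement its parent states for that slot, nothing if the slot is explicitly UnspecifiedDistribution, or the inherited requirement if the parent states nothing for it. Below is a toy sketch of just that propagation rule over a string-labelled tree; the Node type and all names are illustrative, not the actual ExecutionPlan API.

/// A toy plan node: a label, an optional requirement per child slot,
/// and the children themselves.
struct Node {
    name: String,
    // Some(req) -> a concrete requirement for that child slot
    // None      -> "unspecified", which clears any inherited requirement
    // a missing entry (index out of range) -> the child inherits
    child_reqs: Vec<Option<String>>,
    children: Vec<Node>,
}

/// Walk top-down, recording which requirement reaches each node.
fn propagate(node: &Node, inherited: Option<String>, out: &mut Vec<(String, Option<String>)>) {
    out.push((node.name.clone(), inherited.clone()));
    for (idx, child) in node.children.iter().enumerate() {
        let child_req = match node.child_reqs.get(idx) {
            Some(None) => None,                   // explicitly unspecified: drop it
            None => inherited.clone(),            // nothing stated: inherit
            Some(Some(req)) => Some(req.clone()), // concrete requirement wins
        };
        propagate(child, child_req, out);
    }
}

fn main() {
    // An aggregate that requires Hash(a) from its only child; the scan below
    // ends up seeing that requirement.
    let plan = Node {
        name: "AggregateExec".to_string(),
        child_reqs: vec![Some("Hash(a)".to_string())],
        children: vec![Node {
            name: "MergeScanExec".to_string(),
            child_reqs: vec![],
            children: vec![],
        }],
    };
    let mut seen = Vec::new();
    propagate(&plan, None, &mut seen);
    assert_eq!(
        seen[1],
        ("MergeScanExec".to_string(), Some("Hash(a)".to_string()))
    );
    println!("{seen:?}");
}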

View File

@@ -64,29 +64,32 @@ Error: 3000(PlanQuery), Failed to plan SQL: Error during planning: Order by colu
-- SQLNESS REPLACE (partitioning.*) REDACTED
EXPLAIN SELECT a % 2, b FROM test UNION SELECT a % 2 AS k, b FROM test ORDER BY -1;
+---------------+------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+------------------------------------------------------------------------------------------------------------+
| logical_plan | Sort: Int64(-1) ASC NULLS LAST |
| | Aggregate: groupBy=[[test.a % Int64(2), b]], aggr=[[]] |
| | Union |
| | MergeScan [is_placeholder=false, remote_input=[ |
| | Projection: CAST(test.a AS Int64) % Int64(2) AS test.a % Int64(2), test.b |
| | TableScan: test |
| | ]] |
| | MergeScan [is_placeholder=false, remote_input=[ |
| | Projection: CAST(test.a AS Int64) % Int64(2) AS test.a % Int64(2), test.b |
| | TableScan: test |
| | ]] |
| physical_plan | CoalescePartitionsExec |
| | AggregateExec: mode=SinglePartitioned, gby=[test.a % Int64(2)@0 as test.a % Int64(2), b@1 as b], aggr=[] |
| | InterleaveExec |
| | CooperativeExec |
| | MergeScanExec: REDACTED
| | CooperativeExec |
| | MergeScanExec: REDACTED
| | |
+---------------+------------------------------------------------------------------------------------------------------------+
+---------------+-----------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+-----------------------------------------------------------------------------------------------------------+
| logical_plan | Sort: Int64(-1) ASC NULLS LAST |
| | Aggregate: groupBy=[[test.a % Int64(2), b]], aggr=[[]] |
| | Union |
| | MergeScan [is_placeholder=false, remote_input=[ |
| | Projection: CAST(test.a AS Int64) % Int64(2) AS test.a % Int64(2), test.b |
| | TableScan: test |
| | ]] |
| | MergeScan [is_placeholder=false, remote_input=[ |
| | Projection: CAST(test.a AS Int64) % Int64(2) AS test.a % Int64(2), test.b |
| | TableScan: test |
| | ]] |
| physical_plan | CoalescePartitionsExec |
| | AggregateExec: mode=FinalPartitioned, gby=[test.a % Int64(2)@0 as test.a % Int64(2), b@1 as b], aggr=[] |
| | CoalesceBatchesExec: target_batch_size=8192 |
| | RepartitionExec: REDACTED
| | AggregateExec: mode=Partial, gby=[test.a % Int64(2)@0 as test.a % Int64(2), b@1 as b], aggr=[] |
| | InterleaveExec |
| | CooperativeExec |
| | MergeScanExec: REDACTED
| | CooperativeExec |
| | MergeScanExec: REDACTED
| | |
+---------------+-----------------------------------------------------------------------------------------------------------+
SELECT a % 2, b FROM test UNION SELECT a % 2 AS k FROM test ORDER BY -1;
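
In this last hunk the first plan table is the previous expected output (removed) and the second is the updated expectation (added). Read together with the MergeScanExec change above, the new expectation is consistent with the scan no longer advertising a hash partitioning that covers the aggregate's group keys: instead of a single SinglePartitioned aggregate sitting directly on the InterleaveExec, the plan now uses a Partial/FinalPartitioned aggregate pair with a CoalesceBatchesExec and RepartitionExec in between.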