fix: restore distribution after global limit

Signed-off-by: jeremyhi <fengjiachun@gmail.com>
This commit is contained in:
jeremyhi
2026-07-04 03:05:48 +08:00
parent d6be665b96
commit b30e2f3f4e

View File

@@ -19,10 +19,11 @@ use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use datafusion_common::Result as DfResult;
use datafusion_physical_expr::OrderingRequirements;
use datafusion_physical_expr::{Distribution, OrderingRequirements};
#[derive(Debug)]
pub struct EnsureGlobalLimitForFetch;
@@ -54,16 +55,22 @@ impl EnsureGlobalLimitForFetch {
let plan = if children.is_empty() {
plan
} else {
let required_input_distribution = plan.required_input_distribution();
let required_input_ordering = plan.required_input_ordering();
let maintains_input_order = plan.maintains_input_order();
let child_parent = ParentContext {
global_fetch: provided_global_fetch(&plan),
required_ordering: None,
required_distribution: Distribution::UnspecifiedDistribution,
};
let children = children
.into_iter()
.enumerate()
.map(|(idx, child)| {
let required_distribution = required_input_distribution
.get(idx)
.cloned()
.unwrap_or(Distribution::UnspecifiedDistribution);
let required_ordering = required_input_ordering
.get(idx)
.cloned()
@@ -78,6 +85,7 @@ impl EnsureGlobalLimitForFetch {
});
let parent = ParentContext {
required_ordering,
required_distribution,
..child_parent.clone()
};
Self::optimize_plan(Arc::clone(child), parent)
@@ -99,14 +107,30 @@ impl EnsureGlobalLimitForFetch {
return Ok(plan);
}
Ok(add_global_fetch(plan, fetch, parent.required_ordering))
add_global_fetch(
plan,
fetch,
parent.required_ordering,
parent.required_distribution,
)
}
}
#[derive(Clone, Default)]
#[derive(Clone)]
/// Requirements handed from a parent plan node to one of its children while
/// `EnsureGlobalLimitForFetch::optimize_plan` walks down the plan.
struct ParentContext {
/// Result of `provided_global_fetch` for the parent node.
global_fetch: Option<usize>,
/// Ordering requirement for this child, derived from the parent's
/// `required_input_ordering()`; forwarded to `add_global_fetch`.
required_ordering: Option<OrderingRequirements>,
/// Entry of the parent's `required_input_distribution()` for this child
/// (`Distribution::UnspecifiedDistribution` when there is none); forwarded
/// to `add_global_fetch`.
required_distribution: Distribution,
}
impl Default for ParentContext {
fn default() -> Self {
Self {
global_fetch: None,
required_ordering: None,
required_distribution: Distribution::UnspecifiedDistribution,
}
}
}
fn provided_global_fetch(plan: &Arc<dyn ExecutionPlan>) -> Option<usize> {
@@ -121,14 +145,35 @@ fn add_global_fetch(
plan: Arc<dyn ExecutionPlan>,
fetch: usize,
required_ordering: Option<OrderingRequirements>,
) -> Arc<dyn ExecutionPlan> {
if required_ordering.is_some()
required_distribution: Distribution,
) -> DfResult<Arc<dyn ExecutionPlan>> {
let original_partition_count = plan.output_partitioning().partition_count();
let plan = if required_ordering.is_some()
&& let Some(ordering) = plan.output_ordering().cloned()
{
Arc::new(SortPreservingMergeExec::new(ordering, plan).with_fetch(Some(fetch)))
as Arc<dyn ExecutionPlan>
} else {
Arc::new(CoalescePartitionsExec::new(plan).with_fetch(Some(fetch)))
as Arc<dyn ExecutionPlan>
};
restore_required_distribution(plan, required_distribution, original_partition_count)
}
/// Puts a hash repartition back on top of `plan` when the parent needs one.
///
/// * `plan` - the plan to wrap; `add_global_fetch` passes the merge/coalesce
///   node it has just built.
/// * `required_distribution` - the distribution the parent requires from this
///   input.
/// * `partition_count` - the number of output partitions to repartition into;
///   `add_global_fetch` passes the partition count the plan had before the
///   merge/coalesce was added.
///
/// The plan is returned untouched unless the requirement is
/// `Distribution::HashPartitioned` and `partition_count` is greater than one.
///
/// # Errors
///
/// Propagates any error from `RepartitionExec::try_new`.
fn restore_required_distribution(
    plan: Arc<dyn ExecutionPlan>,
    required_distribution: Distribution,
    partition_count: usize,
) -> DfResult<Arc<dyn ExecutionPlan>> {
    let wants_hash = matches!(&required_distribution, Distribution::HashPartitioned(_));

    if wants_hash && partition_count > 1 {
        let target = required_distribution.create_partitioning(partition_count);
        let repartitioned: Arc<dyn ExecutionPlan> =
            Arc::new(RepartitionExec::try_new(plan, target)?.with_preserve_order());
        Ok(repartitioned)
    } else {
        // Nothing to restore: a single partition, or no hash requirement.
        Ok(plan)
    }
}
#[cfg(test)]
@@ -139,10 +184,13 @@ mod tests {
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_expr::expressions::{col, lit};
use datafusion::physical_plan::filter::FilterExecBuilder;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::test::TestMemoryExec;
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_common::{JoinType, NullEquality};
use datafusion_physical_expr::{LexOrdering, Partitioning, PhysicalSortExpr};
use super::*;
@@ -244,7 +292,9 @@ mod tests {
filter,
1,
Some(OrderingRequirements::from(required_ordering)),
);
Distribution::UnspecifiedDistribution,
)
.unwrap();
let merge = optimized
.as_any()
.downcast_ref::<SortPreservingMergeExec>()
@@ -278,6 +328,42 @@ mod tests {
assert_eq!(child.fetch(), Some(1));
}
/// A partitioned hash join whose left input comes from `filter_fetch(.., 1)`
/// over a 3-way hash repartition must, after optimization, read its left side
/// from a 3-way hash `RepartitionExec` sitting on a `CoalescePartitionsExec`
/// that carries the fetch of 1.
#[test]
fn restores_parent_hash_distribution_after_global_fetch() {
    let left_input = filter_fetch(hash_repartition(unordered_input()), 1);
    let right_input = hash_repartition(unordered_input());

    let left_key = col("a", left_input.schema().as_ref()).unwrap();
    let right_key = col("a", right_input.schema().as_ref()).unwrap();
    let join_on = vec![(left_key, right_key)];

    let hash_join = HashJoinExec::try_new(
        left_input,
        right_input,
        join_on,
        None,
        &JoinType::Inner,
        None,
        PartitionMode::Partitioned,
        NullEquality::NullEqualsNothing,
        false,
    )
    .unwrap();
    let join: Arc<dyn ExecutionPlan> = Arc::new(hash_join);

    let optimized =
        EnsureGlobalLimitForFetch::optimize_plan(join, ParentContext::default()).unwrap();

    // The join's left child must be the restored hash repartition.
    let join_children = optimized.children();
    let left_child = join_children[0];
    let repartition = left_child
        .as_any()
        .downcast_ref::<RepartitionExec>()
        .unwrap();
    assert!(matches!(
        repartition.partitioning(),
        Partitioning::Hash(_, 3)
    ));

    // Beneath it sits the coalescing node that enforces the fetch.
    let limited = repartition.input();
    assert!(limited.as_any().is::<CoalescePartitionsExec>());
    assert_eq!(limited.fetch(), Some(1));
}
fn unordered_input() -> Arc<dyn ExecutionPlan> {
let schema = schema();
let batch = batch(schema.clone());
@@ -307,6 +393,11 @@ mod tests {
)
}
/// Test helper: wraps `input` in a `RepartitionExec` that hash-partitions on
/// column `a` into 3 partitions.
///
/// Panics if `input` has no column `a` or the repartition cannot be built.
fn hash_repartition(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    let input_schema = input.schema();
    let hash_key = col("a", input_schema.as_ref()).unwrap();
    let repartition =
        RepartitionExec::try_new(input, Partitioning::Hash(vec![hash_key], 3)).unwrap();
    Arc::new(repartition)
}
/// Test helper: a schema with a single non-nullable `Int32` column named `a`.
fn schema() -> Arc<Schema> {
    let column_a = Field::new("a", DataType::Int32, false);
    Arc::new(Schema::new(vec![column_a]))
}