mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-07-05 21:40:38 +00:00
fix: restore distribution after global limit
Signed-off-by: jeremyhi <fengjiachun@gmail.com>
This commit is contained in:
@@ -19,10 +19,11 @@ use datafusion::physical_optimizer::PhysicalOptimizerRule;
|
||||
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
|
||||
use datafusion::physical_plan::filter::FilterExec;
|
||||
use datafusion::physical_plan::limit::GlobalLimitExec;
|
||||
use datafusion::physical_plan::repartition::RepartitionExec;
|
||||
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
|
||||
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
|
||||
use datafusion_common::Result as DfResult;
|
||||
use datafusion_physical_expr::OrderingRequirements;
|
||||
use datafusion_physical_expr::{Distribution, OrderingRequirements};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EnsureGlobalLimitForFetch;
|
||||
@@ -54,16 +55,22 @@ impl EnsureGlobalLimitForFetch {
|
||||
let plan = if children.is_empty() {
|
||||
plan
|
||||
} else {
|
||||
let required_input_distribution = plan.required_input_distribution();
|
||||
let required_input_ordering = plan.required_input_ordering();
|
||||
let maintains_input_order = plan.maintains_input_order();
|
||||
let child_parent = ParentContext {
|
||||
global_fetch: provided_global_fetch(&plan),
|
||||
required_ordering: None,
|
||||
required_distribution: Distribution::UnspecifiedDistribution,
|
||||
};
|
||||
let children = children
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, child)| {
|
||||
let required_distribution = required_input_distribution
|
||||
.get(idx)
|
||||
.cloned()
|
||||
.unwrap_or(Distribution::UnspecifiedDistribution);
|
||||
let required_ordering = required_input_ordering
|
||||
.get(idx)
|
||||
.cloned()
|
||||
@@ -78,6 +85,7 @@ impl EnsureGlobalLimitForFetch {
|
||||
});
|
||||
let parent = ParentContext {
|
||||
required_ordering,
|
||||
required_distribution,
|
||||
..child_parent.clone()
|
||||
};
|
||||
Self::optimize_plan(Arc::clone(child), parent)
|
||||
@@ -99,14 +107,30 @@ impl EnsureGlobalLimitForFetch {
|
||||
return Ok(plan);
|
||||
}
|
||||
|
||||
Ok(add_global_fetch(plan, fetch, parent.required_ordering))
|
||||
add_global_fetch(
|
||||
plan,
|
||||
fetch,
|
||||
parent.required_ordering,
|
||||
parent.required_distribution,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
#[derive(Clone)]
|
||||
struct ParentContext {
|
||||
global_fetch: Option<usize>,
|
||||
required_ordering: Option<OrderingRequirements>,
|
||||
required_distribution: Distribution,
|
||||
}
|
||||
|
||||
impl Default for ParentContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
global_fetch: None,
|
||||
required_ordering: None,
|
||||
required_distribution: Distribution::UnspecifiedDistribution,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn provided_global_fetch(plan: &Arc<dyn ExecutionPlan>) -> Option<usize> {
|
||||
@@ -121,14 +145,35 @@ fn add_global_fetch(
|
||||
plan: Arc<dyn ExecutionPlan>,
|
||||
fetch: usize,
|
||||
required_ordering: Option<OrderingRequirements>,
|
||||
) -> Arc<dyn ExecutionPlan> {
|
||||
if required_ordering.is_some()
|
||||
required_distribution: Distribution,
|
||||
) -> DfResult<Arc<dyn ExecutionPlan>> {
|
||||
let original_partition_count = plan.output_partitioning().partition_count();
|
||||
let plan = if required_ordering.is_some()
|
||||
&& let Some(ordering) = plan.output_ordering().cloned()
|
||||
{
|
||||
Arc::new(SortPreservingMergeExec::new(ordering, plan).with_fetch(Some(fetch)))
|
||||
as Arc<dyn ExecutionPlan>
|
||||
} else {
|
||||
Arc::new(CoalescePartitionsExec::new(plan).with_fetch(Some(fetch)))
|
||||
as Arc<dyn ExecutionPlan>
|
||||
};
|
||||
|
||||
restore_required_distribution(plan, required_distribution, original_partition_count)
|
||||
}
|
||||
|
||||
fn restore_required_distribution(
|
||||
plan: Arc<dyn ExecutionPlan>,
|
||||
required_distribution: Distribution,
|
||||
partition_count: usize,
|
||||
) -> DfResult<Arc<dyn ExecutionPlan>> {
|
||||
if partition_count <= 1 || !matches!(&required_distribution, Distribution::HashPartitioned(_)) {
|
||||
return Ok(plan);
|
||||
}
|
||||
|
||||
let partitioning = required_distribution.create_partitioning(partition_count);
|
||||
Ok(Arc::new(
|
||||
RepartitionExec::try_new(plan, partitioning)?.with_preserve_order(),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -139,10 +184,13 @@ mod tests {
|
||||
use datafusion::arrow::record_batch::RecordBatch;
|
||||
use datafusion::physical_expr::expressions::{col, lit};
|
||||
use datafusion::physical_plan::filter::FilterExecBuilder;
|
||||
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
|
||||
use datafusion::physical_plan::limit::GlobalLimitExec;
|
||||
use datafusion::physical_plan::projection::ProjectionExec;
|
||||
use datafusion::physical_plan::repartition::RepartitionExec;
|
||||
use datafusion::physical_plan::test::TestMemoryExec;
|
||||
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
|
||||
use datafusion_common::{JoinType, NullEquality};
|
||||
use datafusion_physical_expr::{LexOrdering, Partitioning, PhysicalSortExpr};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -244,7 +292,9 @@ mod tests {
|
||||
filter,
|
||||
1,
|
||||
Some(OrderingRequirements::from(required_ordering)),
|
||||
);
|
||||
Distribution::UnspecifiedDistribution,
|
||||
)
|
||||
.unwrap();
|
||||
let merge = optimized
|
||||
.as_any()
|
||||
.downcast_ref::<SortPreservingMergeExec>()
|
||||
@@ -278,6 +328,42 @@ mod tests {
|
||||
assert_eq!(child.fetch(), Some(1));
|
||||
}
|
||||
|
||||
/// Builds a partitioned inner hash join on column `a` whose left input is
/// `filter_fetch(hash_repartition(..), 1)`, runs the rule over it, and checks
/// that the join's left child is a `RepartitionExec` with `Hash(_, 3)`
/// partitioning sitting on a `CoalescePartitionsExec` that carries fetch 1.
#[test]
fn restores_parent_hash_distribution_after_global_fetch() {
    let left_input = filter_fetch(hash_repartition(unordered_input()), 1);
    let right_input = hash_repartition(unordered_input());
    let left_key = col("a", left_input.schema().as_ref()).unwrap();
    let right_key = col("a", right_input.schema().as_ref()).unwrap();

    let hash_join = HashJoinExec::try_new(
        left_input,
        right_input,
        vec![(left_key, right_key)],
        None,
        &JoinType::Inner,
        None,
        PartitionMode::Partitioned,
        NullEquality::NullEqualsNothing,
        false,
    )
    .unwrap();
    let join: Arc<dyn ExecutionPlan> = Arc::new(hash_join);

    let optimized =
        EnsureGlobalLimitForFetch::optimize_plan(join, ParentContext::default()).unwrap();
    let join_children = optimized.children();
    let repartition = join_children[0]
        .as_any()
        .downcast_ref::<RepartitionExec>()
        .unwrap();
    let restored_input = repartition.input();

    assert!(matches!(
        repartition.partitioning(),
        Partitioning::Hash(_, 3)
    ));
    assert!(restored_input.as_any().is::<CoalescePartitionsExec>());
    assert_eq!(restored_input.fetch(), Some(1));
}
|
||||
|
||||
fn unordered_input() -> Arc<dyn ExecutionPlan> {
|
||||
let schema = schema();
|
||||
let batch = batch(schema.clone());
|
||||
@@ -307,6 +393,11 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
fn hash_repartition(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
|
||||
let partitioning = Partitioning::Hash(vec![col("a", input.schema().as_ref()).unwrap()], 3);
|
||||
Arc::new(RepartitionExec::try_new(input, partitioning).unwrap())
|
||||
}
|
||||
|
||||
fn schema() -> Arc<Schema> {
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user