diff --git a/src/query/src/optimizer/windowed_sort.rs b/src/query/src/optimizer/windowed_sort.rs index 63150fc1f8..0d3c08bb08 100644 --- a/src/query/src/optimizer/windowed_sort.rs +++ b/src/query/src/optimizer/windowed_sort.rs @@ -19,6 +19,7 @@ use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::ExecutionPlan; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::Result as DataFusionResult; @@ -67,10 +68,12 @@ impl WindowedSortPhysicalRule { .transform_down(|plan| { if let Some(sort_exec) = plan.as_any().downcast_ref::() { // TODO: support multiple expr in windowed sort - if !sort_exec.preserve_partitioning() || sort_exec.expr().len() != 1 { + if sort_exec.expr().len() != 1 { return Ok(Transformed::no(plan)); } + let preserve_partitioning = sort_exec.preserve_partitioning(); + let Some(scanner_info) = fetch_partition_range(sort_exec.input().clone())? else { return Ok(Transformed::no(plan)); @@ -110,11 +113,23 @@ impl WindowedSortPhysicalRule { new_input, )?; - return Ok(Transformed { - data: Arc::new(windowed_sort_exec), - transformed: true, - tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop, - }); + if !preserve_partitioning { + let order_preserving_merge = SortPreservingMergeExec::new( + sort_exec.expr().to_vec(), + Arc::new(windowed_sort_exec), + ); + return Ok(Transformed { + data: Arc::new(order_preserving_merge), + transformed: true, + tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop, + }); + } else { + return Ok(Transformed { + data: Arc::new(windowed_sort_exec), + transformed: true, + tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop, + }); + } } Ok(Transformed::no(plan)) @@ -125,6 +140,7 @@ impl WindowedSortPhysicalRule { } } +#[derive(Debug)] struct ScannerInfo { partition_ranges: Vec>, time_index: String,