Skip to main content

query/
part_sort.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Module for sorting input data within each [`PartitionRange`].
16//!
17//! This module defines the [`PartSortExec`] execution plan, which sorts each
18//! partition ([`PartitionRange`]) independently based on the provided physical
19//! sort expressions.
20
21use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::{
27    ArrayRef, AsArray, TimestampMicrosecondArray, TimestampMillisecondArray,
28    TimestampNanosecondArray, TimestampSecondArray,
29};
30use arrow::compute::{concat, concat_batches, take_record_batch};
31use arrow_schema::{Schema, SchemaRef};
32use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
33use common_telemetry::warn;
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datafusion::common::arrow::compute::sort_to_indices;
37use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
38use datafusion::execution::{RecordBatchStream, TaskContext};
39use datafusion::physical_plan::execution_plan::CardinalityEffect;
40use datafusion::physical_plan::filter_pushdown::{
41    ChildFilterDescription, FilterDescription, FilterPushdownPhase,
42};
43use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
44use datafusion::physical_plan::{
45    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
46    TopKDynamicFilters,
47};
48use datafusion_common::tree_node::{Transformed, TreeNode};
49use datafusion_common::{DataFusionError, internal_err};
50use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
51use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
52use futures::{Stream, StreamExt};
53use itertools::Itertools;
54use parking_lot::RwLock;
55use snafu::location;
56use store_api::region_engine::PartitionRange;
57
58use crate::error::Result;
59use crate::window_sort::check_partition_range_monotonicity;
60use crate::{array_iter_helper, downcast_ts_array};
61
62/// Get the primary end of a `PartitionRange` based on sort direction.
63///
64/// - Descending: primary end is `end` (we process highest values first)
65/// - Ascending: primary end is `start` (we process lowest values first)
66fn get_primary_end(range: &PartitionRange, descending: bool) -> Timestamp {
67    if descending { range.end } else { range.start }
68}
69
70/// Group consecutive ranges by their primary end value.
71///
72/// Returns a vector of (primary_end, start_idx_inclusive, end_idx_exclusive) tuples.
73/// Ranges with the same primary end MUST be processed together because they may
74/// overlap and contain values that belong to the same "top-k" result.
75fn group_ranges_by_primary_end(
76    ranges: &[PartitionRange],
77    descending: bool,
78) -> Vec<(Timestamp, usize, usize)> {
79    if ranges.is_empty() {
80        return vec![];
81    }
82
83    let mut groups = Vec::new();
84    let mut group_start = 0;
85    let mut current_primary_end = get_primary_end(&ranges[0], descending);
86
87    for (idx, range) in ranges.iter().enumerate().skip(1) {
88        let primary_end = get_primary_end(range, descending);
89        if primary_end != current_primary_end {
90            // End current group
91            groups.push((current_primary_end, group_start, idx));
92            // Start new group
93            group_start = idx;
94            current_primary_end = primary_end;
95        }
96    }
97    // Push the last group
98    groups.push((current_primary_end, group_start, ranges.len()));
99
100    groups
101}
102
/// Sort input within given PartitionRange
///
/// Input is assumed to be segmented by empty RecordBatch, which indicates a new `PartitionRange` is starting
///
/// and this operator will sort each partition independently within the partition.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expressions(that is, sort by timestamp)
    expression: PhysicalSortExpr,
    /// Optional row limit; when set, TopK-based sorting is used instead of a full sort.
    limit: Option<usize>,
    /// The single child plan whose output is sorted per range.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Per-output-partition lists of `PartitionRange`s, indexed by partition number.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Plan properties derived from the input plan at construction time.
    properties: Arc<PlanProperties>,
    /// Filter matching the state of the sort for dynamic filter pushdown.
    /// If `limit` is `Some`, this will also be set and a TopK operator may be used.
    /// If `limit` is `None`, this will be `None`.
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
123
124impl PartSortExec {
125    pub fn try_new(
126        expression: PhysicalSortExpr,
127        limit: Option<usize>,
128        partition_ranges: Vec<Vec<PartitionRange>>,
129        input: Arc<dyn ExecutionPlan>,
130    ) -> Result<Self> {
131        check_partition_range_monotonicity(&partition_ranges, expression.options.descending)?;
132
133        let metrics = ExecutionPlanMetricsSet::new();
134        let properties = input.properties();
135        let properties = Arc::new(PlanProperties::new(
136            input.equivalence_properties().clone(),
137            input.output_partitioning().clone(),
138            properties.emission_type,
139            properties.boundedness,
140        ));
141
142        let filter = limit
143            .is_some()
144            .then(|| Self::create_filter(expression.expr.clone()));
145
146        Ok(Self {
147            expression,
148            limit,
149            input,
150            metrics,
151            partition_ranges,
152            properties,
153            filter,
154        })
155    }
156
157    /// Add or reset `self.filter` to a new `TopKDynamicFilters`.
158    fn create_filter(expr: Arc<dyn PhysicalExpr>) -> Arc<RwLock<TopKDynamicFilters>> {
159        Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
160            DynamicFilterPhysicalExpr::new(vec![expr], lit(true)),
161        ))))
162    }
163
164    pub fn to_stream(
165        &self,
166        context: Arc<TaskContext>,
167        partition: usize,
168    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
169        let input_stream: DfSendableRecordBatchStream =
170            self.input.execute(partition, context.clone())?;
171
172        if partition >= self.partition_ranges.len() {
173            internal_err!(
174                "Partition index out of range: {} >= {} at {}",
175                partition,
176                self.partition_ranges.len(),
177                snafu::location!()
178            )?;
179        }
180
181        let df_stream = Box::pin(PartSortStream::new(
182            context,
183            self,
184            self.limit,
185            input_stream,
186            self.partition_ranges[partition].clone(),
187            partition,
188            self.filter.clone(),
189        )?) as _;
190
191        Ok(df_stream)
192    }
193}
194
195impl DisplayAs for PartSortExec {
196    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        write!(
198            f,
199            "PartSortExec: expr={} num_ranges={}",
200            self.expression,
201            self.partition_ranges.len(),
202        )?;
203        if let Some(limit) = self.limit {
204            write!(f, " limit={}", limit)?;
205        }
206        Ok(())
207    }
208}
209
impl ExecutionPlan for PartSortExec {
    fn name(&self) -> &str {
        "PartSortExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema equals the input schema: sorting only reorders rows.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        // This node has exactly one child; only the first element is used.
        let new_input = if let Some(first) = children.first() {
            first
        } else {
            internal_err!("No children found")?
        };
        let mut new_exec = self.as_ref().clone();
        new_exec.input = new_input.clone();
        // NOTE(review): properties are taken from the new child as-is, while
        // `try_new` re-derives them from the child's parts — assumed equivalent.
        new_exec.properties = new_input.properties().clone();
        Ok(Arc::new(new_exec))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        self.to_stream(context, partition)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// # Explain
    ///
    /// This plan needs to be executed on each partition independently,
    /// and is expected to run directly on storage engine's output
    /// distribution / partition.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    /// Without a limit rows are only reordered (`Equal`); with a limit the
    /// TopK path can only shrink the row count (`LowerEqual`).
    fn cardinality_effect(&self) -> CardinalityEffect {
        if self.limit.is_none() {
            CardinalityEffect::Equal
        } else {
            CardinalityEffect::LowerEqual
        }
    }

    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &datafusion::config::ConfigOptions,
    ) -> datafusion_common::Result<FilterDescription> {
        // The dynamic TopK filter is only attached during the Post phase;
        // for other phases just forward the parent filters to the child.
        if !matches!(phase, FilterPushdownPhase::Post) {
            return FilterDescription::from_children(parent_filters, &self.children());
        }

        let mut child = ChildFilterDescription::from_child(&parent_filters, &self.input)?;

        // When a limit is set, push the shared TopK dynamic filter down so the
        // scan can prune data using the current top-k threshold.
        if let Some(filter) = &self.filter {
            child = child.with_self_filter(filter.read().expr());
        }

        Ok(FilterDescription::new().with_child(child))
    }

    fn reset_state(self: Arc<Self>) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        // shared dynamic filter needs to be reset
        let new_filter = self
            .limit
            .is_some()
            .then(|| Self::create_filter(self.expression.expr.clone()));

        Ok(Arc::new(Self {
            expression: self.expression.clone(),
            limit: self.limit,
            input: self.input.clone(),
            metrics: self.metrics.clone(),
            partition_ranges: self.partition_ranges.clone(),
            properties: self.properties.clone(),
            filter: new_filter,
        }))
    }
}
312
/// Buffered input rows for the range group currently being read.
enum PartSortBuffer {
    /// Plain accumulation of every input batch; sorted in one shot on flush.
    All(Vec<DfRecordBatch>),
    /// TopK buffer with row count.
    ///
    /// Given this heap only keeps k element, the capacity of this buffer
    /// is not accurate, and is only used for empty check.
    Top(TopK, usize),
}
321
322impl PartSortBuffer {
323    pub fn is_empty(&self) -> bool {
324        match self {
325            PartSortBuffer::All(v) => v.is_empty(),
326            PartSortBuffer::Top(_, cnt) => *cnt == 0,
327        }
328    }
329}
330
/// Stream state for [`PartSortExec`]: buffers rows per range group, sorts
/// them, and emits the sorted batches.
struct PartSortStream {
    /// Memory pool for this stream
    reservation: MemoryReservation,
    /// Rows buffered for the group currently being read (`All` or TopK mode).
    buffer: PartSortBuffer,
    /// The sort expression, copied from the parent plan.
    expression: PhysicalSortExpr,
    /// Optional row limit; `Some` implies TopK mode and a shared dynamic filter.
    limit: Option<usize>,
    /// Upstream input stream.
    input: DfSendableRecordBatchStream,
    /// Whether the input stream has been fully consumed.
    input_complete: bool,
    /// Output schema (taken from the input plan).
    schema: SchemaRef,
    /// The `PartitionRange`s assigned to this output partition.
    partition_ranges: Vec<PartitionRange>,
    #[allow(dead_code)] // this is used under #[debug_assertions]
    partition: usize,
    /// Index of the `PartitionRange` currently being read.
    cur_part_idx: usize,
    /// Remainder of a batch that crossed a range boundary, re-evaluated on the next poll.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline metrics for this stream.
    metrics: BaselineMetrics,
    /// Task context; used to build replacement TopK instances.
    context: Arc<TaskContext>,
    /// Plan-level metrics set shared with the parent plan.
    root_metrics: ExecutionPlanMetricsSet,
    /// Groups of ranges by primary end: (primary_end, start_idx_inclusive, end_idx_exclusive).
    /// Ranges in the same group must be processed together before outputting results.
    range_groups: Vec<(Timestamp, usize, usize)>,
    /// Current group being processed (index into range_groups).
    cur_group_idx: usize,
    /// Dynamic Filter for all TopK instance, notice the `PartSortExec`/`PartSortStream`/`TopK` must share the same filter
    /// so that updates from each `TopK` can be seen by others(and by the table scan operator).
    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
}
357
impl PartSortStream {
    /// Build the stream state for one output partition.
    ///
    /// When `limit` is set, a `TopK` buffer is created and the shared dynamic
    /// filter from the parent plan is required; otherwise all batches are
    /// accumulated and sorted per group.
    fn new(
        context: Arc<TaskContext>,
        sort: &PartSortExec,
        limit: Option<usize>,
        input: DfSendableRecordBatchStream,
        partition_ranges: Vec<PartitionRange>,
        partition: usize,
        filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
    ) -> datafusion_common::Result<Self> {
        let buffer = if let Some(limit) = limit {
            // TopK mode is invalid without the shared dynamic filter.
            let Some(filter) = filter.clone() else {
                return internal_err!(
                    "TopKDynamicFilters must be provided when limit is set at {}",
                    snafu::location!()
                );
            };

            PartSortBuffer::Top(
                TopK::try_new(
                    partition,
                    sort.schema().clone(),
                    vec![],
                    [sort.expression.clone()].into(),
                    limit,
                    context.session_config().batch_size(),
                    context.runtime_env(),
                    &sort.metrics,
                    filter.clone(),
                )?,
                0,
            )
        } else {
            PartSortBuffer::All(Vec::new())
        };

        // Compute range groups by primary end
        let descending = sort.expression.options.descending;
        let range_groups = group_ranges_by_primary_end(&partition_ranges, descending);

        Ok(Self {
            reservation: MemoryConsumer::new("PartSortStream".to_string())
                .register(&context.runtime_env().memory_pool),
            buffer,
            expression: sort.expression.clone(),
            limit,
            input,
            input_complete: false,
            schema: sort.input.schema(),
            partition_ranges,
            partition,
            cur_part_idx: 0,
            evaluating_batch: None,
            metrics: BaselineMetrics::new(&sort.metrics, partition),
            context,
            root_metrics: sort.metrics.clone(),
            range_groups,
            cur_group_idx: 0,
            filter,
        })
    }
}
420
/// Shared body for the per-time-unit range check in `check_in_range`.
///
/// `$t` is the concrete arrow timestamp array type, `$unit` the arrow time
/// unit of the sort column, `$arr` the sort column, `$cur_range` the effective
/// [`PartitionRange`] of the current group, and `$min_max_idx` a pair of row
/// indices pointing at the column's extreme values.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
            if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        // The caller passes indices without a guaranteed order; normalize to (min, max).
        let (min, max) = if min < max{
            (min, max)
        } else {
            (max, min)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // note that PartitionRange is left inclusive and right exclusive
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
458
459impl PartSortStream {
    /// check whether the sort column's min/max value is within the current group's effective range.
    /// For group-based processing, data from multiple ranges with the same primary end
    /// is accumulated together, so we check against the union of all ranges in the group.
    fn check_in_range(
        &self,
        sort_column: &ArrayRef,
        min_max_idx: (usize, usize),
    ) -> datafusion_common::Result<()> {
        // Use the group's effective range instead of the current partition range
        let Some(cur_range) = self.get_current_group_effective_range() else {
            internal_err!(
                "No effective range for current group {} at {}",
                self.cur_group_idx,
                snafu::location!()
            )?
        };

        // Dispatch on the concrete timestamp array type; `array_check_helper`
        // performs the unit match and bounds validation.
        downcast_ts_array!(
            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        Ok(())
    }
487
488    /// Try find data whose value exceeds the current partition range.
489    ///
490    /// Returns `None` if no such data is found, and `Some(idx)` where idx points to
491    /// the first data that exceeds the current partition range.
492    fn try_find_next_range(
493        &self,
494        sort_column: &ArrayRef,
495    ) -> datafusion_common::Result<Option<usize>> {
496        if sort_column.is_empty() {
497            return Ok(None);
498        }
499
500        // check if the current partition index is out of range
501        if self.cur_part_idx >= self.partition_ranges.len() {
502            internal_err!(
503                "Partition index out of range: {} >= {} at {}",
504                self.cur_part_idx,
505                self.partition_ranges.len(),
506                snafu::location!()
507            )?;
508        }
509        let cur_range = self.partition_ranges[self.cur_part_idx];
510
511        let sort_column_iter = downcast_ts_array!(
512            sort_column.data_type() => (array_iter_helper, sort_column),
513            _ => internal_err!(
514                "Unsupported data type for sort column: {:?}",
515                sort_column.data_type()
516            )?,
517        );
518
519        for (idx, val) in sort_column_iter {
520            // ignore vacant time index data
521            if let Some(val) = val
522                && (val >= cur_range.end.value() || val < cur_range.start.value())
523            {
524                return Ok(Some(idx));
525            }
526        }
527
528        Ok(None)
529    }
530
531    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
532        match &mut self.buffer {
533            PartSortBuffer::All(v) => v.push(batch),
534            PartSortBuffer::Top(top, cnt) => {
535                *cnt += batch.num_rows();
536                top.insert_batch(batch)?;
537            }
538        }
539
540        Ok(())
541    }
542
    /// Stop read earlier when current group do not overlap with any of those next group
    /// If not overlap, we can stop read further input as current top k is final
    /// Use dynamic filter to evaluate the next group's primary end
    fn can_stop_early(&mut self, schema: &Arc<Schema>) -> datafusion_common::Result<bool> {
        // Early stop only applies in TopK (limit) mode.
        let topk_cnt = match &self.buffer {
            PartSortBuffer::Top(_, cnt) => *cnt,
            _ => return Ok(false),
        };
        // not fulfill topk yet
        if Some(topk_cnt) < self.limit {
            return Ok(false);
        }
        let next_group_primary_end = if self.cur_group_idx + 1 < self.range_groups.len() {
            self.range_groups[self.cur_group_idx + 1].0
        } else {
            // no next group
            return Ok(false);
        };

        // dyn filter is updated based on the last value of topk heap("threshold")
        // it's a max-heap for a ASC TopK operator
        // so can use dyn filter to prune data range
        let filter = self
            .filter
            .as_ref()
            .expect("TopKDynamicFilters must be provided when limit is set");
        let filter = filter.read().expr().current()?;
        let mut ts_index = None;
        // invariant: the filter must contain only the same column expr that's time index column
        let filter = filter
            .transform_down(|c| {
                // rewrite all column's index as 0
                // (the probe batch built below has exactly one column)
                if let Some(column) = c.as_any().downcast_ref::<Column>() {
                    ts_index = Some(column.index());
                    Ok(Transformed::yes(
                        Arc::new(Column::new(column.name(), 0)) as Arc<dyn PhysicalExpr>
                    ))
                } else {
                    Ok(Transformed::no(c))
                }
            })?
            .data;
        let Some(ts_index) = ts_index else {
            return Ok(false); // dyn filter is still true, cannot decide, continue read
        };
        let field = if schema.fields().len() <= ts_index {
            warn!(
                "Schema mismatch when evaluating dynamic filter for PartSortExec at {}, schema: {:?}, ts_index: {}",
                self.partition, schema, ts_index
            );
            return Ok(false); // schema mismatch, cannot decide, continue read
        } else {
            schema.field(ts_index)
        };
        let schema = Arc::new(Schema::new(vec![field.clone()]));
        // convert next_group_primary_end to array&filter, if eval to false, means no overlap, can stop early
        let primary_end_array = match next_group_primary_end.unit() {
            TimeUnit::Second => Arc::new(TimestampSecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
            TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(vec![
                next_group_primary_end.value(),
            ])) as ArrayRef,
        };
        let primary_end_batch = DfRecordBatch::try_new(schema, vec![primary_end_array])?;
        // Evaluate the rewritten filter against a single-row batch holding the
        // next group's primary end.
        let res = filter.evaluate(&primary_end_batch)?;
        let array = res.into_array(primary_end_batch.num_rows())?;
        let filter = array.as_boolean().clone();
        let overlap = filter.iter().next().flatten();
        // `Some(false)`: the next group's primary end is pruned by the current
        // top-k threshold, i.e. no overlap => safe to stop reading early.
        if let Some(false) = overlap {
            Ok(true)
        } else {
            Ok(false)
        }
    }
624
625    /// Check if the given partition index is within the current group.
626    fn is_in_current_group(&self, part_idx: usize) -> bool {
627        if self.cur_group_idx >= self.range_groups.len() {
628            return false;
629        }
630        let (_, start, end) = self.range_groups[self.cur_group_idx];
631        part_idx >= start && part_idx < end
632    }
633
634    /// Advance to the next group. Returns true if there is a next group.
635    fn advance_to_next_group(&mut self) -> bool {
636        self.cur_group_idx += 1;
637        self.cur_group_idx < self.range_groups.len()
638    }
639
640    /// Get the effective range for the current group.
641    /// For a group of ranges with the same primary end, the effective range is
642    /// the union of all ranges in the group.
643    fn get_current_group_effective_range(&self) -> Option<PartitionRange> {
644        if self.cur_group_idx >= self.range_groups.len() {
645            return None;
646        }
647        let (_, start_idx, end_idx) = self.range_groups[self.cur_group_idx];
648        if start_idx >= end_idx || start_idx >= self.partition_ranges.len() {
649            return None;
650        }
651
652        let ranges_in_group =
653            &self.partition_ranges[start_idx..end_idx.min(self.partition_ranges.len())];
654        if ranges_in_group.is_empty() {
655            return None;
656        }
657
658        // Compute union of all ranges in the group
659        let mut min_start = ranges_in_group[0].start;
660        let mut max_end = ranges_in_group[0].end;
661        for range in ranges_in_group.iter().skip(1) {
662            if range.start < min_start {
663                min_start = range.start;
664            }
665            if range.end > max_end {
666                max_end = range.end;
667            }
668        }
669
670        Some(PartitionRange {
671            start: min_start,
672            end: max_end,
673            num_rows: 0,   // Not used for validation
674            identifier: 0, // Not used for validation
675        })
676    }
677
678    /// Sort and clear the buffer and return the sorted record batch
679    ///
680    /// this function will return a empty record batch if the buffer is empty
681    fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
682        match &mut self.buffer {
683            PartSortBuffer::All(_) => self.sort_all_buffer(),
684            PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
685        }
686    }
687
    /// Internal method for sorting `All` buffer (without limit).
    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        // Swap the buffer out with an empty one so it can be consumed by value.
        let PartSortBuffer::All(buffer) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
        else {
            unreachable!("buffer type is checked before and should be All variant")
        };

        if buffer.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }
        // Evaluate the sort column for every buffered batch; the first
        // non-None sort options win (all batches share the same expression).
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer.iter() {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;
        if indices.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        // Sanity check: the extreme values of the sorted column (first/last
        // sorted index) must fall inside the current group's effective range.
        self.check_in_range(
            &sort_column,
            (
                indices.value(0) as usize,
                indices.value(indices.len() - 1) as usize,
            ),
        )
        .inspect_err(|_e| {
            #[cfg(debug_assertions)]
            common_telemetry::error!(
                "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
                self.partition,
                self.cur_part_idx,
                sort_column.len(),
                _e
            );
        })?;

        // reserve memory for the concat input and sorted output
        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
        self.reservation.try_grow(total_mem * 2)?;

        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat input batches when sorting at {}",
                    location!()
                )),
            )
        })?;

        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to take result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        drop(full_input);
        // here remove both buffer and full_input memory
        self.reservation.shrink(2 * total_mem);
        Ok(sorted)
    }
772
    /// Internal method for sorting `Top` buffer (with limit).
    fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
        let Some(filter) = self.filter.clone() else {
            return internal_err!(
                "TopKDynamicFilters must be provided when sorting with limit at {}",
                snafu::location!()
            );
        };

        // Build a fresh TopK to swap into the buffer, so the finished one can
        // be consumed by value via `emit` below.
        let new_top_buffer = TopK::try_new(
            self.partition,
            self.schema().clone(),
            vec![],
            [self.expression.clone()].into(),
            self.limit.unwrap(),
            self.context.session_config().batch_size(),
            self.context.runtime_env(),
            &self.root_metrics,
            filter,
        )?;
        let PartSortBuffer::Top(top_k, _) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
        else {
            unreachable!("buffer type is checked before and should be Top variant")
        };

        // Drain the emitted stream synchronously with a no-op waker.
        let mut result_stream = top_k.emit()?;
        let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
        let mut results = vec![];
        // according to the current implementation of `TopK`, the result stream will always be ready
        loop {
            match result_stream.poll_next_unpin(&mut placeholder_ctx) {
                Poll::Ready(Some(batch)) => {
                    let batch = batch?;
                    results.push(batch);
                }
                Poll::Pending => {
                    #[cfg(debug_assertions)]
                    unreachable!("TopK result stream should always be ready")
                }
                Poll::Ready(None) => {
                    break;
                }
            }
        }

        let concat_batch = concat_batches(&self.schema, &results).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat top k result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        Ok(concat_batch)
    }
831
832    /// Sorts current buffer and returns `None` when there is nothing to emit.
833    fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
834        if self.buffer.is_empty() {
835            return Ok(None);
836        }
837
838        let sorted = self.sort_buffer()?;
839        if sorted.num_rows() == 0 {
840            Ok(None)
841        } else {
842            Ok(Some(sorted))
843        }
844    }
845
846    /// Try to split the input batch if it contains data that exceeds the current partition range.
847    ///
848    /// When the input batch contains data that exceeds the current partition range, this function
849    /// will split the input batch into two parts, the first part is within the current partition
850    /// range will be merged and sorted with previous buffer, and the second part will be registered
851    /// to `evaluating_batch` for next polling.
852    ///
853    /// **Group-based processing**: Ranges with the same primary end are grouped together.
854    /// We only sort and output when transitioning to a NEW group, not when moving between
855    /// ranges within the same group.
856    ///
857    /// Returns `None` if the input batch is empty or fully within the current partition range
858    /// (or we're still collecting data within the same group), and `Some(batch)` when we've
859    /// completed a group and have sorted output. When operating in TopK (limit) mode, this
860    /// function will not emit intermediate batches; it only prepares state for a single final
861    /// output.
862    fn split_batch(
863        &mut self,
864        batch: DfRecordBatch,
865    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
866        if matches!(self.buffer, PartSortBuffer::Top(_, _)) {
867            self.split_batch_topk(batch)?;
868            return Ok(None);
869        }
870
871        self.split_batch_all(batch)
872    }
873
    /// Specialized splitting logic for TopK (limit) mode.
    ///
    /// We only emit once when the TopK buffer is fulfilled or when input is fully consumed.
    /// When the buffer is fulfilled and we are about to enter a new group, we stop consuming
    /// further ranges.
    fn split_batch_topk(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
        // Nothing to split or buffer for an empty batch.
        if batch.num_rows() == 0 {
            return Ok(());
        }

        // Evaluate the physical sort expression to get the column ranges are split on.
        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        // First row index that falls outside the current partition range, if any.
        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            // Whole batch lies within the current range: buffer it and wait for more input.
            self.push_buffer(batch)?;
            // keep polling input for next batch
            return Ok(());
        };

        // Rows [0, idx) belong to the current range; the tail spills into later ranges.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range)?;
        }

        // Step to next proper PartitionRange
        self.cur_part_idx += 1;

        // If we've processed all partitions, mark completion.
        if self.cur_part_idx >= self.partition_ranges.len() {
            // No data should lie beyond the final partition range.
            debug_assert!(remaining_range.num_rows() == 0);
            self.input_complete = true;
            return Ok(());
        }

        // Check if we're still in the same group
        let in_same_group = self.is_in_current_group(self.cur_part_idx);

        // When TopK is fulfilled and we are switching to a new group, stop consuming further ranges if possible.
        // read from topk heap and determine whether we can stop earlier.
        if !in_same_group && self.can_stop_early(&batch.schema())? {
            // Early stop: drop the tail on purpose; no later range can contribute to TopK.
            self.input_complete = true;
            self.evaluating_batch = None;
            return Ok(());
        }

        // Transition to a new group if needed
        if !in_same_group {
            self.advance_to_next_group();
        }

        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            // remaining batch still contains data that exceeds the current partition range
            // register the remaining batch for next polling
            self.evaluating_batch = Some(remaining_range);
        } else if remaining_range.num_rows() != 0 {
            // remaining batch is within the current partition range
            // push to the buffer and continue polling
            self.push_buffer(remaining_range)?;
        }

        Ok(())
    }
942
    /// Splitting logic for the full-sort (no `limit`) mode.
    ///
    /// Buffers the part of `batch` that belongs to the current group of ranges and,
    /// when a group boundary or the end of all partitions is crossed, sorts the
    /// accumulated buffer and returns it. Returns `Ok(None)` while still collecting.
    fn split_batch_all(
        &mut self,
        batch: DfRecordBatch,
    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        // Evaluate the physical sort expression to get the column ranges are split on.
        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        // First row index that falls outside the current partition range, if any.
        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch)?;
            // keep polling input for next batch
            return Ok(None);
        };

        // Rows [0, idx) belong to the current range; the tail spills into later ranges.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range)?;
        }

        // Step to next proper PartitionRange
        self.cur_part_idx += 1;

        // If we've processed all partitions, sort and output
        if self.cur_part_idx >= self.partition_ranges.len() {
            // assert there is no data beyond the last partition range (remaining is empty).
            debug_assert!(remaining_range.num_rows() == 0);

            // Sort and output the final group
            return self.sorted_buffer_if_non_empty();
        }

        // Check if we're still in the same group
        if self.is_in_current_group(self.cur_part_idx) {
            // Same group - don't sort yet, keep collecting
            let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
            if self.try_find_next_range(&next_sort_column)?.is_some() {
                // remaining batch still contains data that exceeds the current partition range
                self.evaluating_batch = Some(remaining_range);
            } else {
                // remaining batch is within the current partition range
                if remaining_range.num_rows() != 0 {
                    self.push_buffer(remaining_range)?;
                }
            }
            // Return None to continue collecting within the same group
            return Ok(None);
        }

        // Transitioning to a new group - sort current group and output
        let sorted_batch = self.sorted_buffer_if_non_empty()?;
        self.advance_to_next_group();

        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            // remaining batch still contains data that exceeds the current partition range
            // register the remaining batch for next polling
            self.evaluating_batch = Some(remaining_range);
        } else {
            // remaining batch is within the current partition range
            // push to the buffer and continue polling
            if remaining_range.num_rows() != 0 {
                self.push_buffer(remaining_range)?;
            }
        }

        Ok(sorted_batch)
    }
1018
    /// Core polling routine backing the [`Stream`] implementation.
    ///
    /// Loop priority on each iteration:
    /// 1. If the input is exhausted, flush (sort and emit) the remaining buffer
    ///    once and finish the stream.
    /// 2. If a leftover `evaluating_batch` exists from a previous split, split it
    ///    further instead of pulling new input.
    /// 3. Otherwise poll the input stream and split the next batch.
    pub fn poll_next_inner(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        loop {
            if self.input_complete {
                if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                return Poll::Ready(None);
            }

            // if there is a remaining batch being evaluated from last run,
            // split on it instead of fetching new batch
            if let Some(evaluating_batch) = self.evaluating_batch.take()
                && evaluating_batch.num_rows() != 0
            {
                // Check if we've already processed all partitions
                if self.cur_part_idx >= self.partition_ranges.len() {
                    // All partitions processed, discard remaining data
                    if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                    return Poll::Ready(None);
                }

                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                continue;
            }

            // fetch next batch from input
            let res = self.input.as_mut().poll_next(cx);
            match res {
                Poll::Ready(Some(Ok(batch))) => {
                    if let Some(sorted_batch) = self.split_batch(batch)? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                }
                // input stream end, mark and continue
                Poll::Ready(None) => {
                    self.input_complete = true;
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
1068}
1069
impl Stream for PartSortStream {
    type Item = datafusion_common::Result<DfRecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        // Delegate to the inner loop and fold the outcome into the baseline metrics.
        let result = self.as_mut().poll_next_inner(cx);
        self.metrics.record_poll(result)
    }
}
1081
1082impl RecordBatchStream for PartSortStream {
1083    fn schema(&self) -> SchemaRef {
1084        self.schema.clone()
1085    }
1086}
1087
1088#[cfg(test)]
1089mod test {
1090    use std::sync::Arc;
1091
1092    use arrow::array::{
1093        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
1094        TimestampSecondArray,
1095    };
1096    use arrow::json::ArrayWriter;
1097    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
1098    use common_time::Timestamp;
1099    use datafusion_physical_expr::expressions::Column;
1100    use futures::StreamExt;
1101    use store_api::region_engine::PartitionRange;
1102
1103    use super::*;
1104    use crate::test_util::{MockInputExec, new_ts_array};
1105
    /// `can_stop_early` must return `Ok(false)` (and must not panic) when the TopK
    /// buffer holds no rows even though the external row counter reached the limit.
    #[tokio::test]
    async fn test_can_stop_early_with_empty_topk_buffer() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // Build a minimal PartSortExec and stream, but inject a dynamic filter that
        // always evaluates to false so TopK will filter out all rows internally.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: true,
                    ..Default::default()
                },
            },
            Some(3),
            vec![vec![]],
            mock_input.clone(),
        )
        .unwrap();

        // `lit(false)` makes the dynamic filter reject every row before it reaches TopK.
        let filter = Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
            DynamicFilterPhysicalExpr::new(vec![], lit(false)),
        ))));

        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(3),
            input_stream,
            vec![],
            0,
            Some(filter),
        )
        .unwrap();

        // Push 3 rows so the external counter reaches `limit`, while TopK keeps no rows.
        let batch = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
            .unwrap();
        stream.push_buffer(batch).unwrap();

        // The TopK result buffer is empty, so we cannot determine early-stop.
        // Ensure this path returns `Ok(false)` (and, importantly, does not panic).
        assert!(!stream.can_stop_early(&schema).unwrap());
    }
1159
1160    #[ignore = "hard to gen expected data correctly here, TODO(discord9): fix it later"]
1161    #[tokio::test]
1162    async fn fuzzy_test() {
1163        let test_cnt = 100;
1164        // bound for total count of PartitionRange
1165        let part_cnt_bound = 100;
1166        // bound for timestamp range size and offset for each PartitionRange
1167        let range_size_bound = 100;
1168        let range_offset_bound = 100;
1169        // bound for batch count and size within each PartitionRange
1170        let batch_cnt_bound = 20;
1171        let batch_size_bound = 100;
1172
1173        let mut rng = fastrand::Rng::new();
1174        rng.seed(1337);
1175
1176        let mut test_cases = Vec::new();
1177
1178        for case_id in 0..test_cnt {
1179            let mut bound_val: Option<i64> = None;
1180            let descending = rng.bool();
1181            let nulls_first = rng.bool();
1182            let opt = SortOptions {
1183                descending,
1184                nulls_first,
1185            };
1186            let limit = if rng.bool() {
1187                Some(rng.usize(1..batch_cnt_bound * batch_size_bound))
1188            } else {
1189                None
1190            };
1191            let unit = match rng.u8(0..3) {
1192                0 => TimeUnit::Second,
1193                1 => TimeUnit::Millisecond,
1194                2 => TimeUnit::Microsecond,
1195                _ => TimeUnit::Nanosecond,
1196            };
1197
1198            let schema = Schema::new(vec![Field::new(
1199                "ts",
1200                DataType::Timestamp(unit, None),
1201                false,
1202            )]);
1203            let schema = Arc::new(schema);
1204
1205            let mut input_ranged_data = vec![];
1206            let mut output_ranges = vec![];
1207            let mut output_data = vec![];
1208            // generate each input `PartitionRange`
1209            for part_id in 0..rng.usize(0..part_cnt_bound) {
1210                // generate each `PartitionRange`'s timestamp range
1211                let (start, end) = if descending {
1212                    // Use 1..=range_offset_bound to ensure strictly decreasing end values
1213                    let end = bound_val
1214                        .map(
1215                            |i| i
1216                            .checked_sub(rng.i64(1..=range_offset_bound))
1217                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
1218                        )
1219                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
1220                    bound_val = Some(end);
1221                    let start = end - rng.i64(1..range_size_bound);
1222                    let start = Timestamp::new(start, unit.into());
1223                    let end = Timestamp::new(end, unit.into());
1224                    (start, end)
1225                } else {
1226                    // Use 1..=range_offset_bound to ensure strictly increasing start values
1227                    let start = bound_val
1228                        .map(|i| i + rng.i64(1..=range_offset_bound))
1229                        .unwrap_or_else(|| rng.i64(..));
1230                    bound_val = Some(start);
1231                    let end = start + rng.i64(1..range_size_bound);
1232                    let start = Timestamp::new(start, unit.into());
1233                    let end = Timestamp::new(end, unit.into());
1234                    (start, end)
1235                };
1236                assert!(start < end);
1237
1238                let mut per_part_sort_data = vec![];
1239                let mut batches = vec![];
1240                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
1241                    let cnt = rng.usize(0..batch_size_bound) + 1;
1242                    let iter = 0..rng.usize(0..cnt);
1243                    let mut data_gen = iter
1244                        .map(|_| rng.i64(start.value()..end.value()))
1245                        .collect_vec();
1246                    if data_gen.is_empty() {
1247                        // current batch is empty, skip
1248                        continue;
1249                    }
1250                    // mito always sort on ASC order
1251                    data_gen.sort();
1252                    per_part_sort_data.extend(data_gen.clone());
1253                    let arr = new_ts_array(unit, data_gen.clone());
1254                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
1255                    batches.push(batch);
1256                }
1257
1258                let range = PartitionRange {
1259                    start,
1260                    end,
1261                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
1262                    identifier: part_id,
1263                };
1264                input_ranged_data.push((range, batches));
1265
1266                output_ranges.push(range);
1267                if per_part_sort_data.is_empty() {
1268                    continue;
1269                }
1270                output_data.extend_from_slice(&per_part_sort_data);
1271            }
1272
1273            // adjust output data with adjacent PartitionRanges
1274            let mut output_data_iter = output_data.iter().peekable();
1275            let mut output_data = vec![];
1276            for range in output_ranges.clone() {
1277                let mut cur_data = vec![];
1278                while let Some(val) = output_data_iter.peek() {
1279                    if **val < range.start.value() || **val >= range.end.value() {
1280                        break;
1281                    }
1282                    cur_data.push(*output_data_iter.next().unwrap());
1283                }
1284
1285                if cur_data.is_empty() {
1286                    continue;
1287                }
1288
1289                if descending {
1290                    cur_data.sort_by(|a, b| b.cmp(a));
1291                } else {
1292                    cur_data.sort();
1293                }
1294                output_data.push(cur_data);
1295            }
1296
1297            let expected_output = if let Some(limit) = limit {
1298                let mut accumulated = Vec::new();
1299                let mut seen = 0usize;
1300                for mut range_values in output_data {
1301                    seen += range_values.len();
1302                    accumulated.append(&mut range_values);
1303                    if seen >= limit {
1304                        break;
1305                    }
1306                }
1307
1308                if accumulated.is_empty() {
1309                    None
1310                } else {
1311                    if descending {
1312                        accumulated.sort_by(|a, b| b.cmp(a));
1313                    } else {
1314                        accumulated.sort();
1315                    }
1316                    accumulated.truncate(limit.min(accumulated.len()));
1317
1318                    Some(
1319                        DfRecordBatch::try_new(
1320                            schema.clone(),
1321                            vec![new_ts_array(unit, accumulated)],
1322                        )
1323                        .unwrap(),
1324                    )
1325                }
1326            } else {
1327                let batches = output_data
1328                    .into_iter()
1329                    .map(|a| {
1330                        DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
1331                    })
1332                    .collect_vec();
1333                if batches.is_empty() {
1334                    None
1335                } else {
1336                    Some(concat_batches(&schema, &batches).unwrap())
1337                }
1338            };
1339
1340            test_cases.push((
1341                case_id,
1342                unit,
1343                input_ranged_data,
1344                schema,
1345                opt,
1346                limit,
1347                expected_output,
1348            ));
1349        }
1350
1351        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
1352            run_test(
1353                case_id,
1354                input_ranged_data,
1355                schema,
1356                opt,
1357                limit,
1358                expected_output,
1359                None,
1360            )
1361            .await;
1362        }
1363    }
1364
    /// Table-driven cases covering group merging of same-end ranges, empty ranges,
    /// a single batch spanning multiple ranges, and TopK (limit) mode.
    /// Each tuple is `(unit, input_ranged_data, descending, limit, expected_output)`.
    #[tokio::test]
    async fn simple_cases() {
        let testcases = vec![
            // Case 0: Ascending sort; ranges [0,10) and [5,10) share end=10 so they are
            // grouped and their rows merged into one sorted output.
            (
                TimeUnit::Millisecond,
                vec![
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
                ],
                false,
                None,
                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
            ),
            // Case 1: Descending sort with overlapping ranges that have the same primary end (end=10).
            // Ranges [5,10) and [0,10) are grouped together, so their data is merged before sorting.
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 3, 2, 1]],
            ),
            // Case 2: First range carries no batches; the group {[5,10), [0,10)}
            // still sorts as one unit.
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            // Case 3: Distinct groups emit separately: {[15,20)} first, then the
            // merged group {[5,10), [0,10)}; [10,15) contributes nothing.
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![vec![17, 18, 19]]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            // Case 4: All ranges empty -> no output at all.
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![],
            ),
            // Case 5: Data from one batch spans multiple ranges. Ranges with same end are grouped.
            // Ranges: [15,20) end=20, [10,15) end=15, [5,10) end=10, [0,10) end=10
            // Groups: {[15,20)}, {[10,15)}, {[5,10), [0,10)}
            // The last two ranges are merged because they share end=10.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![
                    vec![19, 17, 15],
                    vec![12, 11, 10],
                    vec![9, 8, 7, 6, 5, 4, 3, 2, 1],
                ],
            ),
            // Case 6: Same layout as case 5 but with limit=2 -> only the overall
            // top-2 values are produced in a single batch.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                Some(2),
                vec![vec![19, 17]],
            ),
        ];

        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
            testcases.into_iter().enumerate()
        {
            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);
            let opt = SortOptions {
                descending,
                ..Default::default()
            };

            // Materialize the (start, end)/data tuples into PartitionRanges and record batches.
            let input_ranged_data = input_ranged_data
                .into_iter()
                .map(|(range, data)| {
                    let part = PartitionRange {
                        start: Timestamp::new(range.0, unit.into()),
                        end: Timestamp::new(range.1, unit.into()),
                        num_rows: data.iter().map(|b| b.len()).sum(),
                        identifier,
                    };

                    let batches = data
                        .into_iter()
                        .map(|b| {
                            let arr = new_ts_array(unit, b);
                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
                        })
                        .collect_vec();
                    (part, batches)
                })
                .collect_vec();

            let expected_output = expected_output
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .collect_vec();
            let expected_output = if expected_output.is_empty() {
                None
            } else {
                Some(concat_batches(&schema, &expected_output).unwrap())
            };

            run_test(
                identifier,
                input_ranged_data,
                schema.clone(),
                opt,
                limit,
                expected_output,
                None,
            )
            .await;
        }
    }
1523
    /// Drives one test case end-to-end through [`PartSortExec`].
    ///
    /// Feeds `input_ranged_data` through a mock input, executes the sort with the
    /// given `opt`/`limit`, and compares the concatenated output with
    /// `expected_output` (`None` means the stream must produce nothing). When
    /// `expected_polled_rows` is set, also asserts how many input rows were pulled
    /// from the mock input, which checks TopK early-stop behavior.
    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Option<DfRecordBatch>,
        expected_polled_rows: Option<usize>,
    ) {
        // Sanity-check the expectation itself before running anything.
        if let (Some(limit), Some(rb)) = (limit, &expected_output) {
            assert!(
                rb.num_rows() <= limit,
                "Expect row count in expected output({}) <= limit({})",
                rb.num_rows(),
                limit
            );
        }

        let mut data_partition = Vec::with_capacity(input_ranged_data.len());
        let mut ranges = Vec::with_capacity(input_ranged_data.len());
        for (part_range, batches) in input_ranged_data {
            data_partition.push(batches);
            ranges.push(part_range);
        }

        let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));

        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            mock_input.clone(),
        )
        .unwrap();

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        // TopK mode should emit exactly one final batch (or nothing at all).
        if limit.is_some() {
            assert!(
                real_output.len() <= 1,
                "case_{case_id} expects a single output batch when limit is set, got {}",
                real_output.len()
            );
        }

        let actual_output = if real_output.is_empty() {
            None
        } else {
            Some(concat_batches(&schema, &real_output).unwrap())
        };

        if let Some(expected_polled_rows) = expected_polled_rows {
            let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
            assert_eq!(input_pulled_rows, expected_polled_rows);
        }

        // Compare actual vs expected; on mismatch dump both sides as JSON for readability.
        match (actual_output, expected_output) {
            (None, None) => {}
            (Some(actual), Some(expected)) => {
                if actual != expected {
                    let mut actual_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut actual_json);
                    writer.write(&actual).unwrap();
                    writer.finish().unwrap();

                    let mut expected_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut expected_json);
                    writer.write(&expected).unwrap();
                    writer.finish().unwrap();

                    panic!(
                        "case_{} failed (limit {limit:?}), opt: {:?},\nreal_output: {}\nexpected: {}",
                        case_id,
                        opt,
                        String::from_utf8_lossy(&actual_json),
                        String::from_utf8_lossy(&expected_json),
                    );
                }
            }
            (None, Some(expected)) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output is empty, expected {} rows",
                case_id,
                opt,
                expected.num_rows()
            ),
            (Some(actual), None) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output has {} rows, expected empty",
                case_id,
                opt,
                actual.num_rows()
            ),
        }
    }
1622
1623    /// Test that verifies the limit is correctly applied per partition when
1624    /// multiple batches are received for the same partition.
1625    #[tokio::test]
1626    async fn test_limit_with_multiple_batches_per_partition() {
1627        let unit = TimeUnit::Millisecond;
1628        let schema = Arc::new(Schema::new(vec![Field::new(
1629            "ts",
1630            DataType::Timestamp(unit, None),
1631            false,
1632        )]));
1633
1634        // Test case: Multiple batches in a single partition with limit=3
1635        // Input: 3 batches with [1,2,3], [4,5,6], [7,8,9] all in partition (0,10)
1636        // Expected: Only top 3 values [9,8,7] for descending sort
1637        let input_ranged_data = vec![(
1638            PartitionRange {
1639                start: Timestamp::new(0, unit.into()),
1640                end: Timestamp::new(10, unit.into()),
1641                num_rows: 9,
1642                identifier: 0,
1643            },
1644            vec![
1645                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1646                    .unwrap(),
1647                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1648                    .unwrap(),
1649                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1650                    .unwrap(),
1651            ],
1652        )];
1653
1654        let expected_output = Some(
1655            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
1656                .unwrap(),
1657        );
1658
1659        run_test(
1660            1000,
1661            input_ranged_data,
1662            schema.clone(),
1663            SortOptions {
1664                descending: true,
1665                ..Default::default()
1666            },
1667            Some(3),
1668            expected_output,
1669            None,
1670        )
1671        .await;
1672
1673        // Test case: Multiple batches across multiple partitions with limit=2
1674        // Partition 0: batches [10,11,12], [13,14,15] -> top 2 descending = [15,14]
1675        // Partition 1: batches [1,2,3], [4,5] -> top 2 descending = [5,4]
1676        let input_ranged_data = vec![
1677            (
1678                PartitionRange {
1679                    start: Timestamp::new(10, unit.into()),
1680                    end: Timestamp::new(20, unit.into()),
1681                    num_rows: 6,
1682                    identifier: 0,
1683                },
1684                vec![
1685                    DfRecordBatch::try_new(
1686                        schema.clone(),
1687                        vec![new_ts_array(unit, vec![10, 11, 12])],
1688                    )
1689                    .unwrap(),
1690                    DfRecordBatch::try_new(
1691                        schema.clone(),
1692                        vec![new_ts_array(unit, vec![13, 14, 15])],
1693                    )
1694                    .unwrap(),
1695                ],
1696            ),
1697            (
1698                PartitionRange {
1699                    start: Timestamp::new(0, unit.into()),
1700                    end: Timestamp::new(10, unit.into()),
1701                    num_rows: 5,
1702                    identifier: 1,
1703                },
1704                vec![
1705                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1706                        .unwrap(),
1707                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
1708                        .unwrap(),
1709                ],
1710            ),
1711        ];
1712
1713        let expected_output = Some(
1714            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
1715        );
1716
1717        run_test(
1718            1001,
1719            input_ranged_data,
1720            schema.clone(),
1721            SortOptions {
1722                descending: true,
1723                ..Default::default()
1724            },
1725            Some(2),
1726            expected_output,
1727            None,
1728        )
1729        .await;
1730
1731        // Test case: Ascending sort with limit
1732        // Partition: batches [7,8,9], [4,5,6], [1,2,3] -> top 2 ascending = [1,2]
1733        let input_ranged_data = vec![(
1734            PartitionRange {
1735                start: Timestamp::new(0, unit.into()),
1736                end: Timestamp::new(10, unit.into()),
1737                num_rows: 9,
1738                identifier: 0,
1739            },
1740            vec![
1741                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1742                    .unwrap(),
1743                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1744                    .unwrap(),
1745                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1746                    .unwrap(),
1747            ],
1748        )];
1749
1750        let expected_output = Some(
1751            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
1752        );
1753
1754        run_test(
1755            1002,
1756            input_ranged_data,
1757            schema.clone(),
1758            SortOptions {
1759                descending: false,
1760                ..Default::default()
1761            },
1762            Some(2),
1763            expected_output,
1764            None,
1765        )
1766        .await;
1767    }
1768
1769    /// Test that verifies early termination behavior.
1770    /// Once we've produced limit * num_partitions rows, we should stop
1771    /// pulling from input stream.
1772    #[tokio::test]
1773    async fn test_early_termination() {
1774        let unit = TimeUnit::Millisecond;
1775        let schema = Arc::new(Schema::new(vec![Field::new(
1776            "ts",
1777            DataType::Timestamp(unit, None),
1778            false,
1779        )]));
1780
1781        // Create 3 partitions, each with more data than the limit
1782        // limit=2 per partition, so total expected output = 6 rows
1783        // After producing 6 rows, early termination should kick in
1784        // For descending sort, ranges must be ordered by (end DESC, start DESC)
1785        let input_ranged_data = vec![
1786            (
1787                PartitionRange {
1788                    start: Timestamp::new(20, unit.into()),
1789                    end: Timestamp::new(30, unit.into()),
1790                    num_rows: 10,
1791                    identifier: 2,
1792                },
1793                vec![
1794                    DfRecordBatch::try_new(
1795                        schema.clone(),
1796                        vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
1797                    )
1798                    .unwrap(),
1799                    DfRecordBatch::try_new(
1800                        schema.clone(),
1801                        vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
1802                    )
1803                    .unwrap(),
1804                ],
1805            ),
1806            (
1807                PartitionRange {
1808                    start: Timestamp::new(10, unit.into()),
1809                    end: Timestamp::new(20, unit.into()),
1810                    num_rows: 10,
1811                    identifier: 1,
1812                },
1813                vec![
1814                    DfRecordBatch::try_new(
1815                        schema.clone(),
1816                        vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
1817                    )
1818                    .unwrap(),
1819                    DfRecordBatch::try_new(
1820                        schema.clone(),
1821                        vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
1822                    )
1823                    .unwrap(),
1824                ],
1825            ),
1826            (
1827                PartitionRange {
1828                    start: Timestamp::new(0, unit.into()),
1829                    end: Timestamp::new(10, unit.into()),
1830                    num_rows: 10,
1831                    identifier: 0,
1832                },
1833                vec![
1834                    DfRecordBatch::try_new(
1835                        schema.clone(),
1836                        vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
1837                    )
1838                    .unwrap(),
1839                    DfRecordBatch::try_new(
1840                        schema.clone(),
1841                        vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
1842                    )
1843                    .unwrap(),
1844                ],
1845            ),
1846        ];
1847
1848        // PartSort won't reorder `PartitionRange` (it assumes it's already ordered), so it will not read other partitions.
1849        // This case is just to verify that early termination works as expected.
1850        // First partition [20, 30) produces top 2 values: 29, 28
1851        let expected_output = Some(
1852            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![29, 28])]).unwrap(),
1853        );
1854
1855        run_test(
1856            1003,
1857            input_ranged_data,
1858            schema.clone(),
1859            SortOptions {
1860                descending: true,
1861                ..Default::default()
1862            },
1863            Some(2),
1864            expected_output,
1865            Some(10),
1866        )
1867        .await;
1868    }
1869
1870    /// Example:
1871    /// - Range [70, 100) has data [80, 90, 95]
1872    /// - Range [50, 100) has data [55, 65, 75, 85, 95]
1873    #[tokio::test]
1874    async fn test_primary_end_grouping_with_limit() {
1875        let unit = TimeUnit::Millisecond;
1876        let schema = Arc::new(Schema::new(vec![Field::new(
1877            "ts",
1878            DataType::Timestamp(unit, None),
1879            false,
1880        )]));
1881
1882        // Two ranges with the same end (100) - they should be grouped together
1883        // For descending, ranges are ordered by (end DESC, start DESC)
1884        // So [70, 100) comes before [50, 100) (70 > 50)
1885        let input_ranged_data = vec![
1886            (
1887                PartitionRange {
1888                    start: Timestamp::new(70, unit.into()),
1889                    end: Timestamp::new(100, unit.into()),
1890                    num_rows: 3,
1891                    identifier: 0,
1892                },
1893                vec![
1894                    DfRecordBatch::try_new(
1895                        schema.clone(),
1896                        vec![new_ts_array(unit, vec![80, 90, 95])],
1897                    )
1898                    .unwrap(),
1899                ],
1900            ),
1901            (
1902                PartitionRange {
1903                    start: Timestamp::new(50, unit.into()),
1904                    end: Timestamp::new(100, unit.into()),
1905                    num_rows: 5,
1906                    identifier: 1,
1907                },
1908                vec![
1909                    DfRecordBatch::try_new(
1910                        schema.clone(),
1911                        vec![new_ts_array(unit, vec![55, 65, 75, 85, 95])],
1912                    )
1913                    .unwrap(),
1914                ],
1915            ),
1916        ];
1917
1918        // With limit=4, descending: top 4 values from combined data
1919        // Combined: [80, 90, 95, 55, 65, 75, 85, 95] -> sorted desc: [95, 95, 90, 85, 80, 75, 65, 55]
1920        // Top 4: [95, 95, 90, 85]
1921        let expected_output = Some(
1922            DfRecordBatch::try_new(
1923                schema.clone(),
1924                vec![new_ts_array(unit, vec![95, 95, 90, 85])],
1925            )
1926            .unwrap(),
1927        );
1928
1929        run_test(
1930            2000,
1931            input_ranged_data,
1932            schema.clone(),
1933            SortOptions {
1934                descending: true,
1935                ..Default::default()
1936            },
1937            Some(4),
1938            expected_output,
1939            None,
1940        )
1941        .await;
1942    }
1943
1944    /// Test case with three ranges demonstrating the "keep pulling" behavior.
1945    /// After processing ranges with end=100, the smallest value in top-k might still
1946    /// be reachable by the next group.
1947    ///
1948    /// Ranges: [70, 100), [50, 100), [40, 95)
1949    /// With descending sort and limit=4:
1950    /// - Group 1 (end=100): [70, 100) and [50, 100) merged
1951    /// - Group 2 (end=95): [40, 95)
1952    /// After group 1, smallest in top-4 is 85. Range [40, 95) could have values >= 85,
1953    /// so we continue to group 2.
1954    #[tokio::test]
1955    async fn test_three_ranges_keep_pulling() {
1956        let unit = TimeUnit::Millisecond;
1957        let schema = Arc::new(Schema::new(vec![Field::new(
1958            "ts",
1959            DataType::Timestamp(unit, None),
1960            false,
1961        )]));
1962
1963        // Three ranges, two with same end (100), one with different end (95)
1964        let input_ranged_data = vec![
1965            (
1966                PartitionRange {
1967                    start: Timestamp::new(70, unit.into()),
1968                    end: Timestamp::new(100, unit.into()),
1969                    num_rows: 3,
1970                    identifier: 0,
1971                },
1972                vec![
1973                    DfRecordBatch::try_new(
1974                        schema.clone(),
1975                        vec![new_ts_array(unit, vec![80, 90, 95])],
1976                    )
1977                    .unwrap(),
1978                ],
1979            ),
1980            (
1981                PartitionRange {
1982                    start: Timestamp::new(50, unit.into()),
1983                    end: Timestamp::new(100, unit.into()),
1984                    num_rows: 3,
1985                    identifier: 1,
1986                },
1987                vec![
1988                    DfRecordBatch::try_new(
1989                        schema.clone(),
1990                        vec![new_ts_array(unit, vec![55, 75, 85])],
1991                    )
1992                    .unwrap(),
1993                ],
1994            ),
1995            (
1996                PartitionRange {
1997                    start: Timestamp::new(40, unit.into()),
1998                    end: Timestamp::new(95, unit.into()),
1999                    num_rows: 3,
2000                    identifier: 2,
2001                },
2002                vec![
2003                    DfRecordBatch::try_new(
2004                        schema.clone(),
2005                        vec![new_ts_array(unit, vec![45, 65, 94])],
2006                    )
2007                    .unwrap(),
2008                ],
2009            ),
2010        ];
2011
2012        // All data: [80, 90, 95, 55, 75, 85, 45, 65, 94]
2013        // Sorted descending: [95, 94, 90, 85, 80, 75, 65, 55, 45]
2014        // With limit=4: should be top 4 largest values across all ranges: [95, 94, 90, 85]
2015        let expected_output = Some(
2016            DfRecordBatch::try_new(
2017                schema.clone(),
2018                vec![new_ts_array(unit, vec![95, 94, 90, 85])],
2019            )
2020            .unwrap(),
2021        );
2022
2023        run_test(
2024            2001,
2025            input_ranged_data,
2026            schema.clone(),
2027            SortOptions {
2028                descending: true,
2029                ..Default::default()
2030            },
2031            Some(4),
2032            expected_output,
2033            None,
2034        )
2035        .await;
2036    }
2037
2038    /// Test early termination based on threshold comparison with next group.
2039    /// When the threshold (smallest value for descending) is >= next group's primary end,
2040    /// we can stop early because the next group cannot have better values.
2041    #[tokio::test]
2042    async fn test_threshold_based_early_termination() {
2043        let unit = TimeUnit::Millisecond;
2044        let schema = Arc::new(Schema::new(vec![Field::new(
2045            "ts",
2046            DataType::Timestamp(unit, None),
2047            false,
2048        )]));
2049
2050        // Group 1 (end=100) has 6 rows, TopK will keep top 4
2051        // Group 2 (end=90) has 3 rows - should NOT be processed because
2052        // threshold (96) >= next_primary_end (90)
2053        let input_ranged_data = vec![
2054            (
2055                PartitionRange {
2056                    start: Timestamp::new(70, unit.into()),
2057                    end: Timestamp::new(100, unit.into()),
2058                    num_rows: 6,
2059                    identifier: 0,
2060                },
2061                vec![
2062                    DfRecordBatch::try_new(
2063                        schema.clone(),
2064                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2065                    )
2066                    .unwrap(),
2067                ],
2068            ),
2069            (
2070                PartitionRange {
2071                    start: Timestamp::new(50, unit.into()),
2072                    end: Timestamp::new(90, unit.into()),
2073                    num_rows: 3,
2074                    identifier: 1,
2075                },
2076                vec![
2077                    DfRecordBatch::try_new(
2078                        schema.clone(),
2079                        vec![new_ts_array(unit, vec![85, 86, 87])],
2080                    )
2081                    .unwrap(),
2082                ],
2083            ),
2084        ];
2085
2086        // With limit=4, descending: top 4 from group 1 are [99, 98, 97, 96]
2087        // Threshold is 96, next group's primary_end is 90
2088        // Since 96 >= 90, we stop after group 1
2089        let expected_output = Some(
2090            DfRecordBatch::try_new(
2091                schema.clone(),
2092                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2093            )
2094            .unwrap(),
2095        );
2096
2097        run_test(
2098            2002,
2099            input_ranged_data,
2100            schema.clone(),
2101            SortOptions {
2102                descending: true,
2103                ..Default::default()
2104            },
2105            Some(4),
2106            expected_output,
2107            Some(9), // Pull both batches since all rows fall within the first range
2108        )
2109        .await;
2110    }
2111
2112    /// Test that we continue to next group when threshold is within next group's range.
2113    /// Even after fulfilling limit, if threshold < next_primary_end (descending),
2114    /// we would need to continue... but limit exhaustion stops us first.
2115    #[tokio::test]
2116    async fn test_continue_when_threshold_in_next_group_range() {
2117        let unit = TimeUnit::Millisecond;
2118        let schema = Arc::new(Schema::new(vec![Field::new(
2119            "ts",
2120            DataType::Timestamp(unit, None),
2121            false,
2122        )]));
2123
2124        // Group 1 (end=100) has 6 rows, TopK will keep top 4
2125        // Group 2 (end=98) has 3 rows - threshold (96) < 98, so next group
2126        // could theoretically have better values. Continue reading.
2127        let input_ranged_data = vec![
2128            (
2129                PartitionRange {
2130                    start: Timestamp::new(90, unit.into()),
2131                    end: Timestamp::new(100, unit.into()),
2132                    num_rows: 6,
2133                    identifier: 0,
2134                },
2135                vec![
2136                    DfRecordBatch::try_new(
2137                        schema.clone(),
2138                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2139                    )
2140                    .unwrap(),
2141                ],
2142            ),
2143            (
2144                PartitionRange {
2145                    start: Timestamp::new(50, unit.into()),
2146                    end: Timestamp::new(98, unit.into()),
2147                    num_rows: 3,
2148                    identifier: 1,
2149                },
2150                vec![
2151                    // Values must be < 70 (outside group 1's range) to avoid ambiguity
2152                    DfRecordBatch::try_new(
2153                        schema.clone(),
2154                        vec![new_ts_array(unit, vec![55, 60, 65])],
2155                    )
2156                    .unwrap(),
2157                ],
2158            ),
2159        ];
2160
2161        // With limit=4, we get [99, 98, 97, 96] from group 1
2162        // Threshold is 96, next group's primary_end is 98
2163        // 96 < 98, so threshold check says "could continue"
2164        // But limit is exhausted (0), so we stop anyway
2165        let expected_output = Some(
2166            DfRecordBatch::try_new(
2167                schema.clone(),
2168                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2169            )
2170            .unwrap(),
2171        );
2172
2173        // Note: We pull 9 rows (both batches) because we need to read batch 2
2174        // to detect the group boundary, even though we stop after outputting group 1.
2175        run_test(
2176            2003,
2177            input_ranged_data,
2178            schema.clone(),
2179            SortOptions {
2180                descending: true,
2181                ..Default::default()
2182            },
2183            Some(4),
2184            expected_output,
2185            Some(9), // Pull both batches to detect boundary
2186        )
2187        .await;
2188    }
2189
2190    /// Test ascending sort with threshold-based early termination.
2191    #[tokio::test]
2192    async fn test_ascending_threshold_early_termination() {
2193        let unit = TimeUnit::Millisecond;
2194        let schema = Arc::new(Schema::new(vec![Field::new(
2195            "ts",
2196            DataType::Timestamp(unit, None),
2197            false,
2198        )]));
2199
2200        // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC)
2201        // Group 1 (start=10) has 6 rows
2202        // Group 2 (start=20) has 3 rows - should NOT be processed because
2203        // threshold (13) < next_primary_end (20)
2204        let input_ranged_data = vec![
2205            (
2206                PartitionRange {
2207                    start: Timestamp::new(10, unit.into()),
2208                    end: Timestamp::new(50, unit.into()),
2209                    num_rows: 6,
2210                    identifier: 0,
2211                },
2212                vec![
2213                    DfRecordBatch::try_new(
2214                        schema.clone(),
2215                        vec![new_ts_array(unit, vec![10, 11, 12, 13, 14, 15])],
2216                    )
2217                    .unwrap(),
2218                ],
2219            ),
2220            (
2221                PartitionRange {
2222                    start: Timestamp::new(20, unit.into()),
2223                    end: Timestamp::new(60, unit.into()),
2224                    num_rows: 3,
2225                    identifier: 1,
2226                },
2227                vec![
2228                    DfRecordBatch::try_new(
2229                        schema.clone(),
2230                        vec![new_ts_array(unit, vec![25, 30, 35])],
2231                    )
2232                    .unwrap(),
2233                ],
2234            ),
2235            // still read this batch to detect group boundary(?)
2236            (
2237                PartitionRange {
2238                    start: Timestamp::new(60, unit.into()),
2239                    end: Timestamp::new(70, unit.into()),
2240                    num_rows: 2,
2241                    identifier: 1,
2242                },
2243                vec![
2244                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![60, 61])])
2245                        .unwrap(),
2246                ],
2247            ),
2248            // after boundary detected, this following one should not be read
2249            (
2250                PartitionRange {
2251                    start: Timestamp::new(61, unit.into()),
2252                    end: Timestamp::new(70, unit.into()),
2253                    num_rows: 2,
2254                    identifier: 1,
2255                },
2256                vec![
2257                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![71, 72])])
2258                        .unwrap(),
2259                ],
2260            ),
2261        ];
2262
2263        // With limit=4, ascending: top 4 (smallest) from group 1 are [10, 11, 12, 13]
2264        // Threshold is 13 (largest in top-k), next group's primary_end is 20
2265        // Since 13 < 20, we stop after group 1 (no value in group 2 can be < 13)
2266        let expected_output = Some(
2267            DfRecordBatch::try_new(
2268                schema.clone(),
2269                vec![new_ts_array(unit, vec![10, 11, 12, 13])],
2270            )
2271            .unwrap(),
2272        );
2273
2274        run_test(
2275            2004,
2276            input_ranged_data,
2277            schema.clone(),
2278            SortOptions {
2279                descending: false,
2280                ..Default::default()
2281            },
2282            Some(4),
2283            expected_output,
2284            Some(11), // Pull first two batches to detect boundary
2285        )
2286        .await;
2287    }
2288
2289    #[tokio::test]
2290    async fn test_ascending_threshold_early_termination_case_two() {
2291        let unit = TimeUnit::Millisecond;
2292        let schema = Arc::new(Schema::new(vec![Field::new(
2293            "ts",
2294            DataType::Timestamp(unit, None),
2295            false,
2296        )]));
2297
2298        // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC)
2299        // Group 1 (start=0) has 4 rows, Group 2 (start=4) has 1 row, Group 3 (start=5) has 4 rows
2300        // After reading all data: [9,10,11,12, 21, 5,6,7,8]
2301        // Sorted ascending: [5,6,7,8, 9,10,11,12, 21]
2302        // With limit=4, output should be smallest 4: [5,6,7,8]
2303        // Algorithm continues reading until start=42 > threshold=8, confirming no smaller values exist
2304        let input_ranged_data = vec![
2305            (
2306                PartitionRange {
2307                    start: Timestamp::new(0, unit.into()),
2308                    end: Timestamp::new(20, unit.into()),
2309                    num_rows: 4,
2310                    identifier: 0,
2311                },
2312                vec![
2313                    DfRecordBatch::try_new(
2314                        schema.clone(),
2315                        vec![new_ts_array(unit, vec![9, 10, 11, 12])],
2316                    )
2317                    .unwrap(),
2318                ],
2319            ),
2320            (
2321                PartitionRange {
2322                    start: Timestamp::new(4, unit.into()),
2323                    end: Timestamp::new(25, unit.into()),
2324                    num_rows: 1,
2325                    identifier: 1,
2326                },
2327                vec![
2328                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21])])
2329                        .unwrap(),
2330                ],
2331            ),
2332            (
2333                PartitionRange {
2334                    start: Timestamp::new(5, unit.into()),
2335                    end: Timestamp::new(25, unit.into()),
2336                    num_rows: 4,
2337                    identifier: 1,
2338                },
2339                vec![
2340                    DfRecordBatch::try_new(
2341                        schema.clone(),
2342                        vec![new_ts_array(unit, vec![5, 6, 7, 8])],
2343                    )
2344                    .unwrap(),
2345                ],
2346            ),
2347            // This still will be read to detect boundary, but should not contribute to output
2348            (
2349                PartitionRange {
2350                    start: Timestamp::new(42, unit.into()),
2351                    end: Timestamp::new(52, unit.into()),
2352                    num_rows: 2,
2353                    identifier: 1,
2354                },
2355                vec![
2356                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![42, 51])])
2357                        .unwrap(),
2358                ],
2359            ),
2360            // This following one should not be read after boundary detected
2361            (
2362                PartitionRange {
2363                    start: Timestamp::new(48, unit.into()),
2364                    end: Timestamp::new(53, unit.into()),
2365                    num_rows: 2,
2366                    identifier: 1,
2367                },
2368                vec![
2369                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![48, 51])])
2370                        .unwrap(),
2371                ],
2372            ),
2373        ];
2374
2375        // With limit=4, ascending: after processing all ranges, smallest 4 are [5, 6, 7, 8]
2376        // Threshold is 8 (4th smallest value), algorithm reads until start=42 > threshold=8
2377        let expected_output = Some(
2378            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7, 8])])
2379                .unwrap(),
2380        );
2381
2382        run_test(
2383            2005,
2384            input_ranged_data,
2385            schema.clone(),
2386            SortOptions {
2387                descending: false,
2388                ..Default::default()
2389            },
2390            Some(4),
2391            expected_output,
2392            Some(11), // Read first 4 ranges to confirm threshold boundary
2393        )
2394        .await;
2395    }
2396
    /// Test early stop behavior with null values in sort column.
    /// Verifies that nulls are handled correctly based on nulls_first option.
    #[tokio::test]
    async fn test_early_stop_with_nulls() {
        let unit = TimeUnit::Millisecond;
        // Single nullable timestamp column so null ordering can be exercised.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            true, // nullable
        )]));

        // Helper function to create nullable timestamp array for the given unit.
        let new_nullable_ts_array = |unit: TimeUnit, arr: Vec<Option<i64>>| -> ArrayRef {
            match unit {
                TimeUnit::Second => Arc::new(TimestampSecondArray::from(arr)) as ArrayRef,
                TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(arr)) as ArrayRef,
                TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(arr)) as ArrayRef,
                TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(arr)) as ArrayRef,
            }
        };

        // Test case 1: nulls_first=true, null values should appear first
        // Group 1 (end=100): [null, null, 99, 98, 97] -> with limit=3, top 3 are [null, null, 99]
        // Threshold is 99, next group end=90, since 99 >= 90, we should stop early
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 5,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(99), Some(98), None, Some(97), None],
                        )],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(89), Some(88), Some(87)],
                        )],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // With nulls_first=true, nulls sort before all values
        // For descending, order is: null, null, 99, 98, 97
        // With limit=3, we get: null, null, 99
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_nullable_ts_array(unit, vec![None, None, Some(99)])],
            )
            .unwrap(),
        );

        run_test(
            3000,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                nulls_first: true,
            },
            Some(3),
            expected_output,
            Some(8), // Must read both batches to detect group boundary
        )
        .await;

        // Test case 2: nulls_first=false (i.e. nulls last), null values should appear last
        // Group 1 (end=100): [99, 98, 97, null, null] -> with limit=3, top 3 are [99, 98, 97]
        // Threshold is 97, next group end=90, since 97 >= 90, we should stop early
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 5,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(99), Some(98), Some(97), None, None],
                        )],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(89), Some(88), Some(87)],
                        )],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // With nulls_first=false (equivalently nulls_last=true), values sort before nulls
        // For descending, order is: 99, 98, 97, null, null
        // With limit=3, we get: 99, 98, 97
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_nullable_ts_array(
                    unit,
                    vec![Some(99), Some(98), Some(97)],
                )],
            )
            .unwrap(),
        );

        run_test(
            3001,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                nulls_first: false,
            },
            Some(3),
            expected_output,
            Some(8), // Must read both batches to detect group boundary
        )
        .await;
    }
2555
2556    /// Test early stop behavior when there's only one group (no next group).
2557    /// In this case, can_stop_early should return false and we should process all data.
2558    #[tokio::test]
2559    async fn test_early_stop_single_group() {
2560        let unit = TimeUnit::Millisecond;
2561        let schema = Arc::new(Schema::new(vec![Field::new(
2562            "ts",
2563            DataType::Timestamp(unit, None),
2564            false,
2565        )]));
2566
2567        // Only one group (all ranges have the same end), no next group to compare against
2568        let input_ranged_data = vec![
2569            (
2570                PartitionRange {
2571                    start: Timestamp::new(70, unit.into()),
2572                    end: Timestamp::new(100, unit.into()),
2573                    num_rows: 6,
2574                    identifier: 0,
2575                },
2576                vec![
2577                    DfRecordBatch::try_new(
2578                        schema.clone(),
2579                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2580                    )
2581                    .unwrap(),
2582                ],
2583            ),
2584            (
2585                PartitionRange {
2586                    start: Timestamp::new(50, unit.into()),
2587                    end: Timestamp::new(100, unit.into()),
2588                    num_rows: 3,
2589                    identifier: 1,
2590                },
2591                vec![
2592                    DfRecordBatch::try_new(
2593                        schema.clone(),
2594                        vec![new_ts_array(unit, vec![85, 86, 87])],
2595                    )
2596                    .unwrap(),
2597                ],
2598            ),
2599        ];
2600
2601        // Even though we have enough data in first range, we must process all
2602        // because there's no next group to compare threshold against
2603        let expected_output = Some(
2604            DfRecordBatch::try_new(
2605                schema.clone(),
2606                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2607            )
2608            .unwrap(),
2609        );
2610
2611        run_test(
2612            3002,
2613            input_ranged_data,
2614            schema.clone(),
2615            SortOptions {
2616                descending: true,
2617                ..Default::default()
2618            },
2619            Some(4),
2620            expected_output,
2621            Some(9), // Must read all batches since no early stop is possible
2622        )
2623        .await;
2624    }
2625
2626    /// Test early stop behavior when threshold exactly equals next group's boundary.
2627    #[tokio::test]
2628    async fn test_early_stop_exact_boundary_equality() {
2629        let unit = TimeUnit::Millisecond;
2630        let schema = Arc::new(Schema::new(vec![Field::new(
2631            "ts",
2632            DataType::Timestamp(unit, None),
2633            false,
2634        )]));
2635
2636        // Test case 1: Descending sort, threshold == next_group_end
2637        // Group 1 (end=100): data up to 90, threshold = 90, next_group_end = 90
2638        // Since 90 >= 90, we should stop early
2639        let input_ranged_data = vec![
2640            (
2641                PartitionRange {
2642                    start: Timestamp::new(70, unit.into()),
2643                    end: Timestamp::new(100, unit.into()),
2644                    num_rows: 4,
2645                    identifier: 0,
2646                },
2647                vec![
2648                    DfRecordBatch::try_new(
2649                        schema.clone(),
2650                        vec![new_ts_array(unit, vec![92, 91, 90, 89])],
2651                    )
2652                    .unwrap(),
2653                ],
2654            ),
2655            (
2656                PartitionRange {
2657                    start: Timestamp::new(50, unit.into()),
2658                    end: Timestamp::new(90, unit.into()),
2659                    num_rows: 3,
2660                    identifier: 1,
2661                },
2662                vec![
2663                    DfRecordBatch::try_new(
2664                        schema.clone(),
2665                        vec![new_ts_array(unit, vec![88, 87, 86])],
2666                    )
2667                    .unwrap(),
2668                ],
2669            ),
2670        ];
2671
2672        let expected_output = Some(
2673            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![92, 91, 90])])
2674                .unwrap(),
2675        );
2676
2677        run_test(
2678            3003,
2679            input_ranged_data,
2680            schema.clone(),
2681            SortOptions {
2682                descending: true,
2683                ..Default::default()
2684            },
2685            Some(3),
2686            expected_output,
2687            Some(7), // Must read both batches to detect boundary
2688        )
2689        .await;
2690
2691        // Test case 2: Ascending sort, threshold == next_group_start
2692        // Group 1 (start=10): data from 10, threshold = 20, next_group_start = 20
2693        // Since 20 < 20 is false, we should continue
2694        let input_ranged_data = vec![
2695            (
2696                PartitionRange {
2697                    start: Timestamp::new(10, unit.into()),
2698                    end: Timestamp::new(50, unit.into()),
2699                    num_rows: 4,
2700                    identifier: 0,
2701                },
2702                vec![
2703                    DfRecordBatch::try_new(
2704                        schema.clone(),
2705                        vec![new_ts_array(unit, vec![10, 15, 20, 25])],
2706                    )
2707                    .unwrap(),
2708                ],
2709            ),
2710            (
2711                PartitionRange {
2712                    start: Timestamp::new(20, unit.into()),
2713                    end: Timestamp::new(60, unit.into()),
2714                    num_rows: 3,
2715                    identifier: 1,
2716                },
2717                vec![
2718                    DfRecordBatch::try_new(
2719                        schema.clone(),
2720                        vec![new_ts_array(unit, vec![21, 22, 23])],
2721                    )
2722                    .unwrap(),
2723                ],
2724            ),
2725        ];
2726
2727        let expected_output = Some(
2728            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![10, 15, 20])])
2729                .unwrap(),
2730        );
2731
2732        run_test(
2733            3004,
2734            input_ranged_data,
2735            schema.clone(),
2736            SortOptions {
2737                descending: false,
2738                ..Default::default()
2739            },
2740            Some(3),
2741            expected_output,
2742            Some(7), // Must read both batches since 20 is not < 20
2743        )
2744        .await;
2745    }
2746
    /// Test early stop behavior with empty partition groups.
    /// Empty ranges must not be mistaken for a satisfied threshold: the stream
    /// has to keep scanning until it finds actual data or proves early stop.
    #[tokio::test]
    async fn test_early_stop_with_empty_partitions() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // Test case 1: First group is empty, second group has data.
        // Ranges 0 and 1 share end=100 (first group) but contain zero rows each;
        // the first real rows come from the range ending at 80.
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 0,
                    identifier: 0,
                },
                vec![
                    // Empty batch for first range
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
                        .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 0,
                    identifier: 1,
                },
                vec![
                    // Empty batch for second range
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
                        .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(30, unit.into()),
                    end: Timestamp::new(80, unit.into()),
                    num_rows: 4,
                    identifier: 2,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![74, 75, 76, 77])],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(60, unit.into()),
                    num_rows: 3,
                    identifier: 3,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![58, 59, 60])],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // Group 1 (end=100) is empty, Group 2 (end=80) has data
        // Should continue to Group 2 since Group 1 has no data
        let expected_output = Some(
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![77, 76])]).unwrap(),
        );

        run_test(
            3005,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(2),
            expected_output,
            Some(7), // Must read until finding actual data
        )
        .await;

        // Test case 2: Empty partitions between data groups.
        // The first group (end=100) already satisfies the limit; the two empty
        // groups in between (end=90, end=70) should not block early stop.
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 4,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![96, 97, 98, 99])],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 0,
                    identifier: 1,
                },
                vec![
                    // Empty range - should be skipped
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
                        .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(30, unit.into()),
                    end: Timestamp::new(70, unit.into()),
                    num_rows: 0,
                    identifier: 2,
                },
                vec![
                    // Another empty range
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
                        .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(50, unit.into()),
                    num_rows: 3,
                    identifier: 3,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![48, 49, 50])],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // With limit=2 from group 1: [99, 98], threshold=98, next non-empty group end=50
        // Since 98 >= 50 (and also >= the empty groups' ends), we should stop early
        let expected_output = Some(
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![99, 98])]).unwrap(),
        );

        run_test(
            3006,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(2),
            expected_output,
            Some(7), // Must read to detect early stop condition
        )
        .await;
    }
2917
    /// White-box test of the dynamic-filter snapshot generation.
    ///
    /// First group: [0,20), data: [0, 5, 15]
    /// Second group: [10, 30), data: [21, 25, 29]
    ///
    /// Drives a [`PartSortStream`] by hand (push buffer, check early stop, sort)
    /// and asserts exactly when the filter's `snapshot_generation` counter is
    /// bumped, then that the second group's rows are filtered out entirely.
    /// The statement order here is load-bearing; do not reorder.
    #[tokio::test]
    async fn test_early_stop_check_update_dyn_filter() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // Mock input with no batches; data is pushed into the stream manually below.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: false,
                    ..Default::default()
                },
            },
            Some(3),
            // NOTE(review): both ranges use identifier 1 — presumably the
            // identifier is not significant for this test; confirm.
            vec![vec![
                PartitionRange {
                    start: Timestamp::new(0, unit.into()),
                    end: Timestamp::new(20, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(30, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
            ]],
            mock_input.clone(),
        )
        .unwrap();

        let filter = exec.filter.clone().unwrap();
        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(3),
            input_stream,
            vec![],
            0,
            Some(filter.clone()),
        )
        .unwrap();

        // initially, snapshot_generation is 1
        assert_eq!(filter.read().expr().snapshot_generation(), 1);
        let batch =
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![0, 5, 15])])
                .unwrap();
        stream.push_buffer(batch).unwrap();

        // after pushing first batch, snapshot_generation is updated to 2
        assert_eq!(filter.read().expr().snapshot_generation(), 2);
        // early stop not possible yet (second group overlaps the first)
        assert!(!stream.can_stop_early(&schema).unwrap());
        // still two as not updated
        assert_eq!(filter.read().expr().snapshot_generation(), 2);

        // sort result of the first group is not needed here, only its side effects
        let _ = stream.sort_top_buffer().unwrap();

        let batch =
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21, 25, 29])])
                .unwrap();
        stream.push_buffer(batch).unwrap();
        // still two as not updated
        assert_eq!(filter.read().expr().snapshot_generation(), 2);
        let new = stream.sort_top_buffer().unwrap();
        // still two as not updated
        assert_eq!(filter.read().expr().snapshot_generation(), 2);

        // dyn filter kick in, and filter out all rows >= 15(the filter is rows<15)
        assert_eq!(new.num_rows(), 0)
    }
3001}