// part_sort.rs — query module
1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Module for sorting input data within each [`PartitionRange`].
16//!
17//! This module defines the [`PartSortExec`] execution plan, which sorts each
18//! partition ([`PartitionRange`]) independently based on the provided physical
19//! sort expressions.
20
21use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::{Array, ArrayRef};
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::{DataType, SchemaRef, TimeUnit};
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use common_time::Timestamp;
31use datafusion::common::arrow::compute::sort_to_indices;
32use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
33use datafusion::execution::{RecordBatchStream, TaskContext};
34use datafusion::physical_plan::execution_plan::CardinalityEffect;
35use datafusion::physical_plan::filter_pushdown::{
36    ChildFilterDescription, FilterDescription, FilterPushdownPhase,
37};
38use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
39use datafusion::physical_plan::{
40    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
41};
42use datafusion_common::{DataFusionError, ScalarValue, internal_err};
43use datafusion_expr::Operator;
44use datafusion_physical_expr::expressions::{
45    BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit,
46};
47use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
48use futures::Stream;
49use itertools::Itertools;
50use snafu::location;
51use store_api::region_engine::PartitionRange;
52
53use crate::error::Result;
54use crate::window_sort::{check_partition_range_monotonicity, project_partition_range_for_sort};
55use crate::{array_iter_helper, downcast_ts_array};
56
57/// Get the primary end of a `PartitionRange` based on sort direction.
58///
59/// - Descending: primary end is `end` (we process highest values first)
60/// - Ascending: primary end is `start` (we process lowest values first)
61fn get_primary_end(range: &PartitionRange, descending: bool) -> Timestamp {
62    if descending { range.end } else { range.start }
63}
64
65/// Group consecutive ranges by their primary end value.
66///
67/// Returns a vector of (primary_end, start_idx_inclusive, end_idx_exclusive) tuples.
68/// Ranges with the same primary end MUST be processed together because they may
69/// overlap and contain values that belong to the same "top-k" result.
70fn group_ranges_by_primary_end(
71    ranges: &[PartitionRange],
72    descending: bool,
73) -> Vec<(Timestamp, usize, usize)> {
74    if ranges.is_empty() {
75        return vec![];
76    }
77
78    let mut groups = Vec::new();
79    let mut group_start = 0;
80    let mut current_primary_end = get_primary_end(&ranges[0], descending);
81
82    for (idx, range) in ranges.iter().enumerate().skip(1) {
83        let primary_end = get_primary_end(range, descending);
84        if primary_end != current_primary_end {
85            // End current group
86            groups.push((current_primary_end, group_start, idx));
87            // Start new group
88            group_start = idx;
89            current_primary_end = primary_end;
90        }
91    }
92    // Push the last group
93    groups.push((current_primary_end, group_start, ranges.len()));
94
95    groups
96}
97
/// Sort input within given PartitionRange
///
/// Partition range transitions are detected by comparing sort column values against
/// the current range boundaries (via [`PartSortStream::try_find_next_range`]).
/// Empty RecordBatches from upstream are tolerated but do not serve as range delimiters.
///
/// This operator sorts each partition independently.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expressions(that is, sort by timestamp)
    expression: PhysicalSortExpr,
    /// Optional row limit; when set the operator acts as a per-partition top-k sort.
    limit: Option<usize>,
    /// Child plan whose output is sorted within each partition range.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Per output-partition lists of `PartitionRange`s, indexed by partition number.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Cached plan properties (equivalences, partitioning, emission, boundedness),
    /// mirrored from the child in `try_new`.
    properties: Arc<PlanProperties>,
    /// Dynamic top-k filter pushed down to the input; present only when `limit` is set.
    dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
}
117
118impl PartSortExec {
119    pub fn try_new(
120        expression: PhysicalSortExpr,
121        limit: Option<usize>,
122        partition_ranges: Vec<Vec<PartitionRange>>,
123        input: Arc<dyn ExecutionPlan>,
124    ) -> Result<Self> {
125        check_partition_range_monotonicity(&partition_ranges, expression.options.descending)?;
126
127        let metrics = ExecutionPlanMetricsSet::new();
128        let properties = input.properties();
129        let properties = Arc::new(PlanProperties::new(
130            input.equivalence_properties().clone(),
131            input.output_partitioning().clone(),
132            properties.emission_type,
133            properties.boundedness,
134        ));
135
136        let dynamic_filter = Self::new_dynamic_filter(&expression, limit);
137
138        Ok(Self {
139            expression,
140            limit,
141            input,
142            metrics,
143            partition_ranges,
144            properties,
145            dynamic_filter,
146        })
147    }
148
149    fn new_dynamic_filter(
150        expression: &PhysicalSortExpr,
151        limit: Option<usize>,
152    ) -> Option<Arc<DynamicFilterPhysicalExpr>> {
153        limit.map(|_| {
154            Arc::new(DynamicFilterPhysicalExpr::new(
155                vec![expression.expr.clone()],
156                lit(true),
157            ))
158        })
159    }
160
161    pub fn to_stream(
162        &self,
163        context: Arc<TaskContext>,
164        partition: usize,
165    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
166        let input_stream: DfSendableRecordBatchStream =
167            self.input.execute(partition, context.clone())?;
168
169        if partition >= self.partition_ranges.len() {
170            internal_err!(
171                "Partition index out of range: {} >= {} at {}",
172                partition,
173                self.partition_ranges.len(),
174                snafu::location!()
175            )?;
176        }
177
178        let df_stream = Box::pin(PartSortStream::new(
179            context,
180            self,
181            self.limit,
182            input_stream,
183            self.partition_ranges[partition].clone(),
184            partition,
185        )?) as _;
186
187        Ok(df_stream)
188    }
189}
190
191impl DisplayAs for PartSortExec {
192    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        write!(
194            f,
195            "PartSortExec: expr={} num_ranges={}",
196            self.expression,
197            self.partition_ranges.len(),
198        )?;
199        if let Some(limit) = self.limit {
200            write!(f, " limit={}", limit)?;
201        }
202        Ok(())
203    }
204}
205
impl ExecutionPlan for PartSortExec {
    fn name(&self) -> &str {
        "PartSortExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema is identical to the child's: sorting reorders rows only.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Replace the single child; extra children beyond the first are ignored.
    /// Properties are refreshed from the new child so partitioning stays in sync.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let new_input = if let Some(first) = children.first() {
            first
        } else {
            internal_err!("No children found")?
        };
        let mut new_exec = self.as_ref().clone();
        new_exec.input = new_input.clone();
        new_exec.properties = new_input.properties().clone();
        Ok(Arc::new(new_exec))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        self.to_stream(context, partition)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// # Explain
    ///
    /// This plan needs to be executed on each partition independently,
    /// and is expected to run directly on storage engine's output
    /// distribution / partition.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    /// Without a limit the sort is row-preserving; with a limit it can only
    /// reduce the row count.
    fn cardinality_effect(&self) -> CardinalityEffect {
        if self.limit.is_none() {
            CardinalityEffect::Equal
        } else {
            CardinalityEffect::LowerEqual
        }
    }

    /// Forward parent filters to the child; during the `Post` phase also attach
    /// this node's own dynamic top-k filter (when present and enabled by config).
    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        config: &datafusion::config::ConfigOptions,
    ) -> datafusion_common::Result<FilterDescription> {
        if !matches!(phase, FilterPushdownPhase::Post) {
            return FilterDescription::from_children(parent_filters, &self.children());
        }

        let mut child = ChildFilterDescription::from_child(&parent_filters, &self.input)?;
        if let Some(filter) = &self.dynamic_filter
            && config.optimizer.enable_topk_dynamic_filter_pushdown
        {
            let filter: Arc<dyn PhysicalExpr> = filter.clone();
            child = child.with_self_filter(filter);
        }
        Ok(FilterDescription::new().with_child(child))
    }

    /// Rebuild with a fresh dynamic filter so a re-execution does not reuse a
    /// threshold learned from a previous run. Note metrics are shared with the
    /// original instance via `clone` — presumably intentional; confirm if
    /// per-run metrics isolation is expected.
    fn reset_state(self: Arc<Self>) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let dynamic_filter = Self::new_dynamic_filter(&self.expression, self.limit);
        Ok(Arc::new(Self {
            expression: self.expression.clone(),
            limit: self.limit,
            input: self.input.clone(),
            metrics: self.metrics.clone(),
            partition_ranges: self.partition_ranges.clone(),
            properties: self.properties.clone(),
            dynamic_filter,
        }))
    }
}
304
/// Buffer of input batches accumulated for the partition range being sorted.
enum PartSortBuffer {
    /// Full-sort mode: every batch is kept until the range completes.
    All(Vec<DfRecordBatch>),
    /// Top-k mode (a limit is set): periodically compacted down to `limit` rows.
    TopK(Vec<DfRecordBatch>),
}
309
/// Sort value of the k-th (last retained) row of the current top-k result.
///
/// `PartialEq`/`Eq` let callers skip re-publishing an unchanged threshold to
/// the dynamic filter.
#[derive(Clone, Debug, PartialEq, Eq)]
enum TopKThreshold {
    /// The k-th row's sort value is NULL.
    Null,
    /// The k-th row's raw timestamp value (unit given by the sort column's type).
    Value(i64),
}
315
316impl PartSortBuffer {
317    pub fn is_empty(&self) -> bool {
318        match self {
319            PartSortBuffer::All(v) => v.is_empty(),
320            PartSortBuffer::TopK(v) => v.is_empty(),
321        }
322    }
323
324    pub fn num_rows(&self) -> usize {
325        match self {
326            PartSortBuffer::All(v) => v.iter().map(|batch| batch.num_rows()).sum(),
327            PartSortBuffer::TopK(v) => v.iter().map(|batch| batch.num_rows()).sum(),
328        }
329    }
330}
331
/// Streaming state for [`PartSortExec`]: buffers, sorts, and emits each
/// partition range (group) of one output partition.
struct PartSortStream {
    /// Memory pool for this stream
    reservation: MemoryReservation,
    /// Batches accumulated for the range/group currently being sorted.
    buffer: PartSortBuffer,
    /// Sort expression (copied from the parent exec node).
    expression: PhysicalSortExpr,
    /// Optional top-k limit; `None` means full sort.
    limit: Option<usize>,
    /// Upstream input stream for this partition.
    input: DfSendableRecordBatchStream,
    /// Whether the upstream stream has been exhausted (set by the poll loop,
    /// which is outside this excerpt).
    input_complete: bool,
    /// Output schema, identical to the input's.
    schema: SchemaRef,
    /// The `PartitionRange`s assigned to this output partition, in order.
    partition_ranges: Vec<PartitionRange>,
    #[allow(dead_code)] // this is used under #[debug_assertions]
    partition: usize,
    /// Index into `partition_ranges` of the range currently being consumed.
    cur_part_idx: usize,
    /// Carry-over batch whose rows straddle a range boundary — NOTE(review):
    /// consumed by the poll loop outside this excerpt; verify semantics there.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline execution metrics for this partition.
    metrics: BaselineMetrics,
    /// Dynamic top-k filter shared with the parent exec (None without a limit).
    dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
    /// Last threshold published to `dynamic_filter`, used to skip no-op updates.
    dynamic_filter_threshold: Option<TopKThreshold>,
    /// Groups of ranges by primary end: (primary_end, start_idx_inclusive, end_idx_exclusive).
    /// Ranges in the same group must be processed together before outputting results.
    range_groups: Vec<(Timestamp, usize, usize)>,
    /// Current group being processed (index into range_groups).
    cur_group_idx: usize,
}
355
356impl PartSortStream {
357    fn new(
358        context: Arc<TaskContext>,
359        sort: &PartSortExec,
360        limit: Option<usize>,
361        input: DfSendableRecordBatchStream,
362        partition_ranges: Vec<PartitionRange>,
363        partition: usize,
364    ) -> datafusion_common::Result<Self> {
365        let buffer = if limit.is_some() {
366            PartSortBuffer::TopK(Vec::new())
367        } else {
368            PartSortBuffer::All(Vec::new())
369        };
370
371        // Compute range groups by primary end
372        let descending = sort.expression.options.descending;
373        let range_groups = group_ranges_by_primary_end(&partition_ranges, descending);
374
375        Ok(Self {
376            reservation: MemoryConsumer::new("PartSortStream".to_string())
377                .register(&context.runtime_env().memory_pool),
378            buffer,
379            expression: sort.expression.clone(),
380            limit,
381            input,
382            input_complete: false,
383            schema: sort.input.schema(),
384            partition_ranges,
385            partition,
386            cur_part_idx: 0,
387            evaluating_batch: None,
388            metrics: BaselineMetrics::new(&sort.metrics, partition),
389            dynamic_filter: sort.dynamic_filter.clone(),
390            dynamic_filter_threshold: None,
391            range_groups,
392            cur_group_idx: 0,
393        })
394    }
395}
396
/// Assert that a timestamp sort column's min/max values (at `$min_max_idx`)
/// fall within `$cur_range`, after verifying the array's time unit matches the
/// range's. `$t` is the arrow timestamp primitive type and `$unit` its
/// `TimeUnit`; violations surface as DataFusion internal errors.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
            if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // The caller passes candidate (min, max) positions; normalize so that
        // min <= max regardless of sort direction.
        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        let (min, max) = if min < max{
            (min, max)
        } else {
            (max, min)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // note that PartitionRange is left inclusive and right exclusive
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
434
/// Read the threshold value at `$threshold_idx` from a timestamp array of
/// primitive type `$t`, yielding `TopKThreshold::Null` for NULL slots and
/// `TopKThreshold::Value` otherwise. (`$unit` is accepted for signature
/// compatibility with `downcast_ts_array!` but unused here.)
macro_rules! threshold_helper {
    ($t:ty, $unit:expr, $arr:expr, $threshold_idx:expr) => {{
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();
        if arr.is_null($threshold_idx) {
            TopKThreshold::Null
        } else {
            TopKThreshold::Value(arr.value($threshold_idx))
        }
    }};
}
448
449impl PartSortStream {
    /// check whether the sort column's min/max value is within the current group's effective range.
    /// For group-based processing, data from multiple ranges with the same primary end
    /// is accumulated together, so we check against the union of all ranges in the group.
    ///
    /// # Errors
    /// Internal error when the current group has no effective range, when the
    /// column's time unit mismatches the range, when values fall outside the
    /// range, or when the column is not a timestamp type.
    fn check_in_range(
        &self,
        sort_column: &ArrayRef,
        min_max_idx: (usize, usize),
    ) -> datafusion_common::Result<()> {
        // Use the group's effective range instead of the current partition range
        let Some(cur_range) = self.get_current_group_effective_range() else {
            internal_err!(
                "No effective range for current group {} at {}",
                self.cur_group_idx,
                snafu::location!()
            )?
        };
        // Project the range into the sort column's time unit before comparing.
        let cur_range = project_partition_range_for_sort(cur_range, sort_column.data_type())?;

        downcast_ts_array!(
            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        Ok(())
    }
478
    /// Try find data whose value exceeds the current partition range.
    ///
    /// Returns `None` if no such data is found, and `Some(idx)` where idx points to
    /// the first data that exceeds the current partition range.
    ///
    /// # Errors
    /// Internal error when `cur_part_idx` is out of bounds or the sort column
    /// is not a timestamp type.
    fn try_find_next_range(
        &self,
        sort_column: &ArrayRef,
    ) -> datafusion_common::Result<Option<usize>> {
        if sort_column.is_empty() {
            return Ok(None);
        }

        // check if the current partition index is out of range
        if self.cur_part_idx >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                self.cur_part_idx,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }
        // Project the current range into the sort column's time unit.
        let cur_range = project_partition_range_for_sort(
            self.partition_ranges[self.cur_part_idx],
            sort_column.data_type(),
        )?;

        let sort_column_iter = downcast_ts_array!(
            sort_column.data_type() => (array_iter_helper, sort_column),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        // Linear scan: report the first value outside [start, end).
        for (idx, val) in sort_column_iter {
            // ignore vacant time index data
            if let Some(val) = val
                && (val >= cur_range.end.value() || val < cur_range.start.value())
            {
                return Ok(Some(idx));
            }
        }

        Ok(None)
    }
524
525    fn push_buffer(
526        &mut self,
527        batch: DfRecordBatch,
528        sort_data_type: &DataType,
529    ) -> datafusion_common::Result<()> {
530        let topk = matches!(self.buffer, PartSortBuffer::TopK(_));
531        match &mut self.buffer {
532            PartSortBuffer::All(v) => v.push(batch),
533            PartSortBuffer::TopK(v) => v.push(batch),
534        }
535
536        if topk {
537            let threshold = self.compact_topk_buffer(sort_data_type)?;
538            self.update_dynamic_filter(sort_data_type, threshold)?;
539        }
540
541        Ok(())
542    }
543
    /// Compact the top-k buffer so it never holds more than `limit` rows.
    ///
    /// Returns the new threshold (the k-th retained row's sort value) when a
    /// compaction actually truncated rows, `None` otherwise.
    ///
    /// NOTE(review): this is only meaningful in `TopK` mode; if called while the
    /// buffer is `All`, the mem::replace below would silently discard it — the
    /// caller (`push_buffer`) guards against that. Confirm no other call sites.
    fn compact_topk_buffer(
        &mut self,
        sort_data_type: &DataType,
    ) -> datafusion_common::Result<Option<TopKThreshold>> {
        let Some(limit) = self.limit else {
            return Ok(None);
        };

        // Take ownership of the buffered batches, leaving an empty TopK buffer.
        let PartSortBuffer::TopK(buffer) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
        else {
            return Ok(None);
        };

        if limit == 0 || buffer.is_empty() {
            self.buffer = PartSortBuffer::TopK(Vec::new());
            return Ok(None);
        }

        // Nothing to truncate: put the batches back untouched.
        let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum();
        if total_rows <= limit {
            self.buffer = PartSortBuffer::TopK(buffer);
            return Ok(None);
        }

        // Sort and keep only the first `limit` rows; range checking is skipped
        // here (check_range = false) since compaction may span group boundaries.
        let topk = self.sort_record_batches(&buffer, Some(limit), false)?;
        let threshold = self.threshold_from_sorted_batch(&topk, sort_data_type)?;
        self.buffer = if topk.num_rows() == 0 {
            PartSortBuffer::TopK(Vec::new())
        } else {
            PartSortBuffer::TopK(vec![topk])
        };

        Ok(threshold)
    }
579
    /// Extract the top-k threshold from an already-sorted batch: the sort value
    /// of its LAST row (the worst value still kept).
    ///
    /// Returns `None` for an empty batch.
    fn threshold_from_sorted_batch(
        &self,
        batch: &DfRecordBatch,
        sort_data_type: &DataType,
    ) -> datafusion_common::Result<Option<TopKThreshold>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        // Last row of a sorted, limit-truncated batch is the k-th best value.
        let threshold_idx = batch.num_rows() - 1;
        let sort_column = self.expression.evaluate_to_sort_column(batch)?.values;
        let threshold = downcast_ts_array!(
            sort_data_type => (threshold_helper, sort_column, threshold_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_data_type
            )?,
        );

        Ok(Some(threshold))
    }
601
    /// Compute the current top-k threshold from the (unsorted) buffer.
    ///
    /// Unlike [`Self::threshold_from_sorted_batch`], this sorts the buffered
    /// sort columns on the fly (without materializing the sorted rows) and
    /// reads the k-th value. Returns `None` when there is no limit or fewer
    /// than `limit` rows are buffered.
    fn topk_threshold(
        &self,
        sort_data_type: &arrow_schema::DataType,
    ) -> datafusion_common::Result<Option<TopKThreshold>> {
        let Some(limit) = self.limit else {
            return Ok(None);
        };

        if limit == 0 || self.buffer.num_rows() < limit {
            return Ok(None);
        }

        let buffer = match &self.buffer {
            PartSortBuffer::All(buffer) | PartSortBuffer::TopK(buffer) => buffer,
        };
        // Evaluate the sort expression per batch; the first batch that carries
        // explicit sort options wins (`opt.or(...)`).
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        let indices = sort_to_indices(&sort_column, opt, Some(limit)).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;

        if indices.len() < limit {
            return Ok(None);
        }

        // The last of the `limit` sorted indices points at the k-th best value.
        let threshold_idx = indices.value(indices.len() - 1) as usize;
        let threshold = downcast_ts_array!(
            sort_data_type => (threshold_helper, sort_column, threshold_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_data_type
            )?,
        );

        Ok(Some(threshold))
    }
655
656    fn threshold_scalar_value(
657        sort_data_type: &DataType,
658        threshold: &TopKThreshold,
659    ) -> datafusion_common::Result<ScalarValue> {
660        let value = match threshold {
661            TopKThreshold::Null => None,
662            TopKThreshold::Value(value) => Some(*value),
663        };
664
665        let scalar = match sort_data_type {
666            DataType::Timestamp(TimeUnit::Second, tz) => {
667                ScalarValue::TimestampSecond(value, tz.clone())
668            }
669            DataType::Timestamp(TimeUnit::Millisecond, tz) => {
670                ScalarValue::TimestampMillisecond(value, tz.clone())
671            }
672            DataType::Timestamp(TimeUnit::Microsecond, tz) => {
673                ScalarValue::TimestampMicrosecond(value, tz.clone())
674            }
675            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
676                ScalarValue::TimestampNanosecond(value, tz.clone())
677            }
678            _ => internal_err!(
679                "Unsupported data type for sort column: {:?}",
680                sort_data_type
681            )?,
682        };
683
684        Ok(scalar)
685    }
686
    /// Build the pushdown predicate that keeps only rows which could still beat
    /// the current top-k `threshold`.
    ///
    /// The comparison is `sort_col > threshold` for descending and
    /// `sort_col < threshold` for ascending. NULL handling follows the sort's
    /// `nulls_first` option:
    /// - nulls_first & NULL threshold: nothing can beat NULL → `false`.
    /// - nulls_first & value threshold: NULLs always pass, plus the comparison.
    /// - nulls_last & NULL threshold: any non-null beats NULL → `IS NOT NULL`.
    /// - nulls_last & value threshold: just the comparison.
    fn build_dynamic_filter_expr(
        &self,
        sort_data_type: &DataType,
        threshold: &TopKThreshold,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        let op = if self.expression.options.descending {
            Operator::Gt
        } else {
            Operator::Lt
        };
        let value_null = matches!(threshold, TopKThreshold::Null);
        let value = Self::threshold_scalar_value(sort_data_type, threshold)?;
        let comparison: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            self.expression.expr.clone(),
            op,
            lit(value),
        ));

        match (self.expression.options.nulls_first, value_null) {
            (true, true) => Ok(lit(false)),
            (true, false) => Ok(Arc::new(BinaryExpr::new(
                is_null(self.expression.expr.clone())?,
                Operator::Or,
                comparison,
            ))),
            (false, true) => is_not_null(self.expression.expr.clone()),
            (false, false) => Ok(comparison),
        }
    }
716
717    fn update_dynamic_filter(
718        &mut self,
719        sort_data_type: &DataType,
720        threshold: Option<TopKThreshold>,
721    ) -> datafusion_common::Result<()> {
722        let Some(filter) = &self.dynamic_filter else {
723            return Ok(());
724        };
725
726        let threshold = if let Some(threshold) = threshold {
727            threshold
728        } else {
729            let Some(threshold) = self.topk_threshold(sort_data_type)? else {
730                return Ok(());
731            };
732            threshold
733        };
734
735        if self.dynamic_filter_threshold.as_ref() == Some(&threshold) {
736            return Ok(());
737        }
738
739        let predicate = self.build_dynamic_filter_expr(sort_data_type, &threshold)?;
740        filter.update(predicate)?;
741        self.dynamic_filter_threshold = Some(threshold);
742
743        Ok(())
744    }
745
    /// Returns true when all rows in the next group are guaranteed to be worse
    /// than the current top-k threshold.
    ///
    /// "Worse" means no row in group `group_idx` could displace the k-th row
    /// already held, given the primary-end monotonicity of the range groups.
    /// Returns `Ok(false)` when the group index is out of bounds or no
    /// threshold is available yet.
    fn can_stop_before_group(
        &self,
        group_idx: usize,
        sort_data_type: &arrow_schema::DataType,
    ) -> datafusion_common::Result<bool> {
        if group_idx >= self.range_groups.len() {
            return Ok(false);
        }

        // Prefer the threshold already published to the dynamic filter;
        // otherwise compute one from the buffer on the fly.
        let threshold = if let Some(threshold) = &self.dynamic_filter_threshold {
            threshold.clone()
        } else {
            let Some(threshold) = self.topk_threshold(sort_data_type)? else {
                return Ok(false);
            };
            threshold
        };

        // The group's first range carries the group's shared primary end.
        let (_, start_idx, _) = self.range_groups[group_idx];
        let next_range =
            project_partition_range_for_sort(self.partition_ranges[start_idx], sort_data_type)?;
        let descending = self.expression.options.descending;
        let next_primary = get_primary_end(&next_range, descending).value();

        let can_stop = match threshold {
            // When the k-th element is NULL:
            // - nulls_first=true: nulls are the best values (Arrow sorts NULLs first
            //   regardless of ASC/DESC), so top-k is already optimal → stop early.
            // - nulls_first=false: nulls are the worst, non-null values from the
            //   next group could displace them → continue reading.
            TopKThreshold::Null => self.expression.options.nulls_first,
            TopKThreshold::Value(value) => {
                // Descending: range end is exclusive, so the next group's best
                // is strictly below `next_primary` — a threshold >= it wins.
                // Ascending: range start is inclusive, so only a threshold
                // strictly below `next_primary` is safe.
                if descending {
                    value >= next_primary
                } else {
                    value < next_primary
                }
            }
        };

        Ok(can_stop)
    }
790
791    /// Check if the given partition index is within the current group.
792    fn is_in_current_group(&self, part_idx: usize) -> bool {
793        if self.cur_group_idx >= self.range_groups.len() {
794            return false;
795        }
796        let (_, start, end) = self.range_groups[self.cur_group_idx];
797        part_idx >= start && part_idx < end
798    }
799
800    /// Advance to the next group. Returns true if there is a next group.
801    fn advance_to_next_group(&mut self) -> bool {
802        self.cur_group_idx += 1;
803        self.cur_group_idx < self.range_groups.len()
804    }
805
806    /// Get the effective range for the current group.
807    /// For a group of ranges with the same primary end, the effective range is
808    /// the union of all ranges in the group.
809    fn get_current_group_effective_range(&self) -> Option<PartitionRange> {
810        if self.cur_group_idx >= self.range_groups.len() {
811            return None;
812        }
813        let (_, start_idx, end_idx) = self.range_groups[self.cur_group_idx];
814        if start_idx >= end_idx || start_idx >= self.partition_ranges.len() {
815            return None;
816        }
817
818        let ranges_in_group =
819            &self.partition_ranges[start_idx..end_idx.min(self.partition_ranges.len())];
820        if ranges_in_group.is_empty() {
821            return None;
822        }
823
824        // Compute union of all ranges in the group
825        let mut min_start = ranges_in_group[0].start;
826        let mut max_end = ranges_in_group[0].end;
827        for range in ranges_in_group.iter().skip(1) {
828            if range.start < min_start {
829                min_start = range.start;
830            }
831            if range.end > max_end {
832                max_end = range.end;
833            }
834        }
835
836        Some(PartitionRange {
837            start: min_start,
838            end: max_end,
839            num_rows: 0,   // Not used for validation
840            identifier: 0, // Not used for validation
841        })
842    }
843
844    /// Sort and clear the buffer and return the sorted record batch
845    ///
846    /// this function will return a empty record batch if the buffer is empty
847    fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
848        match &mut self.buffer {
849            PartSortBuffer::All(_) => self.sort_all_buffer(),
850            PartSortBuffer::TopK(_) => self.sort_topk_buffer(),
851        }
852    }
853
    /// Concatenates `buffer`, sorts it by the stream's sort expression, and
    /// returns the sorted (and optionally `limit`-truncated) record batch.
    ///
    /// * `limit` — when `Some(n)`, `sort_to_indices` produces at most `n`
    ///   indices, so only the top `n` rows in sort order are materialized.
    /// * `check_range` — when true, validates that the extreme values of the
    ///   sort column fall inside the current partition range.
    ///
    /// Memory for the concatenated input plus the sorted output is reserved
    /// before materialization and released before returning.
    fn sort_record_batches(
        &mut self,
        buffer: &[DfRecordBatch],
        limit: Option<usize>,
        check_range: bool,
    ) -> datafusion_common::Result<DfRecordBatch> {
        if buffer.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        // Evaluate the sort expression once per batch; `opt` keeps the first
        // non-None sort options encountered across batches.
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer.iter() {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        let indices = sort_to_indices(&sort_column, opt, limit).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;
        if indices.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        if check_range {
            // The first and last sorted indices point at the extreme values of
            // the sort column; checking those two suffices to bound the rest.
            self.check_in_range(
                &sort_column,
                (
                    indices.value(0) as usize,
                    indices.value(indices.len() - 1) as usize,
                ),
            )
            .inspect_err(|_e| {
                #[cfg(debug_assertions)]
                common_telemetry::error!(
                    "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
                    self.partition,
                    self.cur_part_idx,
                    sort_column.len(),
                    _e
                );
            })?;
        }

        // reserve memory for the concat input and sorted output
        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
        self.reservation.try_grow(total_mem * 2)?;

        let full_input = concat_batches(&self.schema, buffer).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat input batches when sorting at {}",
                    location!()
                )),
            )
        })?;

        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to take result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        drop(full_input);
        // here remove both buffer and full_input memory
        self.reservation.shrink(2 * total_mem);
        Ok(sorted)
    }
939
940    /// Internal method for sorting `All` buffer (without limit).
941    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
942        let PartSortBuffer::All(buffer) =
943            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
944        else {
945            unreachable!()
946        };
947
948        self.sort_record_batches(&buffer, self.limit, self.limit.is_none())
949    }
950
951    fn sort_topk_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
952        let PartSortBuffer::TopK(buffer) =
953            std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
954        else {
955            unreachable!()
956        };
957
958        self.sort_record_batches(&buffer, self.limit, false)
959    }
960
961    /// Sorts current buffer and returns `None` when there is nothing to emit.
962    fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
963        if self.buffer.is_empty() {
964            return Ok(None);
965        }
966
967        let sorted = self.sort_buffer()?;
968        if sorted.num_rows() == 0 {
969            Ok(None)
970        } else {
971            Ok(Some(sorted))
972        }
973    }
974
975    fn mark_dynamic_filter_complete(&self) {
976        if let Some(filter) = &self.dynamic_filter {
977            filter.mark_complete();
978        }
979    }
980
981    /// Try to split the input batch if it contains data that exceeds the current partition range.
982    ///
983    /// When the input batch contains data that exceeds the current partition range, this function
984    /// will split the input batch into two parts, the first part is within the current partition
985    /// range will be merged and sorted with previous buffer, and the second part will be registered
986    /// to `evaluating_batch` for next polling.
987    ///
988    /// **Group-based processing**: Ranges with the same primary end are grouped together.
989    /// We only sort and output when transitioning to a NEW group, not when moving between
990    /// ranges within the same group.
991    ///
992    /// Returns `None` if the input batch is empty or fully within the current partition range
993    /// (or we're still collecting data within the same group), and `Some(batch)` when we've
994    /// completed a group and have sorted output. When operating in limit mode, this
995    /// function will not emit intermediate batches; it only prepares state for a single final
996    /// output.
997    fn split_batch(
998        &mut self,
999        batch: DfRecordBatch,
1000    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
1001        if self.limit.is_some() {
1002            self.split_batch_topk(batch)?;
1003            return Ok(None);
1004        }
1005
1006        self.split_batch_all(batch)
1007    }
1008
    /// Specialized splitting logic for limit (top-k) mode.
    ///
    /// We only emit once when input is fully consumed.
    /// When the buffer is fulfilled and we are about to enter a new group, we stop consuming
    /// further ranges (see `can_stop_before_group`).
    fn split_batch_topk(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
        if batch.num_rows() == 0 {
            return Ok(());
        }

        // Evaluate the sort expression so range boundaries can be located.
        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        // Index of the first row falling outside the current range, if any.
        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch, sort_column.data_type())?;
            // keep polling input for next batch
            return Ok(());
        };

        // Rows [0, idx) belong to the current range; the rest spill over.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range, sort_column.data_type())?;
        }

        // Step to next proper PartitionRange
        self.cur_part_idx += 1;

        // If we've processed all partitions, mark completion.
        if self.cur_part_idx >= self.partition_ranges.len() {
            debug_assert!(remaining_range.num_rows() == 0);
            self.input_complete = true;
            return Ok(());
        }

        // Check if we're still in the same group
        let in_same_group = self.is_in_current_group(self.cur_part_idx);

        if !in_same_group {
            // Entering a new group: if enough rows are already buffered to
            // satisfy the limit, stop consuming input early.
            let next_group_idx = self.cur_group_idx + 1;
            if self.can_stop_before_group(next_group_idx, sort_column.data_type())? {
                self.input_complete = true;
                return Ok(());
            }
            self.advance_to_next_group();
        }

        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            // remaining batch still contains data that exceeds the current partition range
            // register the remaining batch for next polling
            self.evaluating_batch = Some(remaining_range);
        } else if remaining_range.num_rows() != 0 {
            // remaining batch is within the current partition range
            // push to the buffer and continue polling
            self.push_buffer(remaining_range, sort_column.data_type())?;
        }

        Ok(())
    }
1073
    /// Splitting logic for non-limit mode: emits one sorted batch each time a
    /// range group is completed.
    ///
    /// Returns `None` while still collecting within the current group, and
    /// `Some(sorted)` when crossing into a new group (or after the final
    /// range) with buffered data to output.
    fn split_batch_all(
        &mut self,
        batch: DfRecordBatch,
    ) -> datafusion_common::Result<Option<DfRecordBatch>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        // Evaluate the sort expression so range boundaries can be located.
        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        // Index of the first row falling outside the current range, if any.
        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch, sort_column.data_type())?;
            // keep polling input for next batch
            return Ok(None);
        };

        // Rows [0, idx) belong to the current range; the rest spill over.
        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range, sort_column.data_type())?;
        }

        // Step to next proper PartitionRange
        self.cur_part_idx += 1;

        // If we've processed all partitions, sort and output
        if self.cur_part_idx >= self.partition_ranges.len() {
            // assert there is no data beyond the last partition range (remaining is empty).
            debug_assert!(remaining_range.num_rows() == 0);

            // Sort and output the final group
            return self.sorted_buffer_if_non_empty();
        }

        // Check if we're still in the same group
        if self.is_in_current_group(self.cur_part_idx) {
            // Same group - don't sort yet, keep collecting
            let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
            if self.try_find_next_range(&next_sort_column)?.is_some() {
                // remaining batch still contains data that exceeds the current partition range
                self.evaluating_batch = Some(remaining_range);
            } else {
                // remaining batch is within the current partition range
                if remaining_range.num_rows() != 0 {
                    self.push_buffer(remaining_range, sort_column.data_type())?;
                }
            }
            // Return None to continue collecting within the same group
            return Ok(None);
        }

        // Transitioning to a new group - sort current group and output.
        // Note: the buffer must be sorted before advancing the group cursor.
        let sorted_batch = self.sorted_buffer_if_non_empty()?;
        self.advance_to_next_group();

        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            // remaining batch still contains data that exceeds the current partition range
            // register the remaining batch for next polling
            self.evaluating_batch = Some(remaining_range);
        } else {
            // remaining batch is within the current partition range
            // push to the buffer and continue polling
            if remaining_range.num_rows() != 0 {
                self.push_buffer(remaining_range, sort_column.data_type())?;
            }
        }

        Ok(sorted_batch)
    }
1149
    /// Core polling loop: first drains any batch left over from a previous
    /// split, then pulls from the input stream, splitting each batch along
    /// partition-range/group boundaries and emitting sorted batches as groups
    /// complete.
    pub fn poll_next_inner(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        loop {
            // Input exhausted: flush whatever is still buffered, then finish.
            if self.input_complete {
                if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                    self.mark_dynamic_filter_complete();
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                self.mark_dynamic_filter_complete();
                return Poll::Ready(None);
            }

            // if there is a remaining batch being evaluated from last run,
            // split on it instead of fetching new batch
            if let Some(evaluating_batch) = self.evaluating_batch.take()
                && evaluating_batch.num_rows() != 0
            {
                // Check if we've already processed all partitions
                if self.cur_part_idx >= self.partition_ranges.len() {
                    // All partitions processed, discard remaining data
                    // (`evaluating_batch` was taken above and is dropped here).
                    if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                        self.mark_dynamic_filter_complete();
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                    self.mark_dynamic_filter_complete();
                    return Poll::Ready(None);
                }

                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                // `split_batch` may have re-parked a remainder; loop to retry.
                continue;
            }

            // fetch next batch from input
            let res = self.input.as_mut().poll_next(cx);
            match res {
                Poll::Ready(Some(Ok(batch))) => {
                    if let Some(sorted_batch) = self.split_batch(batch)? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                }
                // input stream end, mark and continue
                Poll::Ready(None) => {
                    self.input_complete = true;
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
1203}
1204
1205impl Stream for PartSortStream {
1206    type Item = datafusion_common::Result<DfRecordBatch>;
1207
1208    fn poll_next(
1209        mut self: Pin<&mut Self>,
1210        cx: &mut Context<'_>,
1211    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
1212        let result = self.as_mut().poll_next_inner(cx);
1213        self.metrics.record_poll(result)
1214    }
1215}
1216
1217impl RecordBatchStream for PartSortStream {
1218    fn schema(&self) -> SchemaRef {
1219        self.schema.clone()
1220    }
1221}
1222
1223#[cfg(test)]
1224mod test {
1225    use std::sync::Arc;
1226
1227    use arrow::array::{
1228        BooleanArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1229        TimestampNanosecondArray, TimestampSecondArray,
1230    };
1231    use arrow::json::ArrayWriter;
1232    use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
1233    use common_time::Timestamp;
1234    use datafusion_physical_expr::expressions::Column;
1235    use futures::StreamExt;
1236    use store_api::region_engine::PartitionRange;
1237
1238    use super::*;
1239    use crate::test_util::{MockInputExec, new_ts_array};
1240
1241    #[ignore = "hard to gen expected data correctly here, TODO(discord9): fix it later"]
1242    #[tokio::test]
1243    async fn fuzzy_test() {
1244        let test_cnt = 100;
1245        // bound for total count of PartitionRange
1246        let part_cnt_bound = 100;
1247        // bound for timestamp range size and offset for each PartitionRange
1248        let range_size_bound = 100;
1249        let range_offset_bound = 100;
1250        // bound for batch count and size within each PartitionRange
1251        let batch_cnt_bound = 20;
1252        let batch_size_bound = 100;
1253
1254        let mut rng = fastrand::Rng::new();
1255        rng.seed(1337);
1256
1257        let mut test_cases = Vec::new();
1258
1259        for case_id in 0..test_cnt {
1260            let mut bound_val: Option<i64> = None;
1261            let descending = rng.bool();
1262            let nulls_first = rng.bool();
1263            let opt = SortOptions {
1264                descending,
1265                nulls_first,
1266            };
1267            let limit = if rng.bool() {
1268                Some(rng.usize(1..batch_cnt_bound * batch_size_bound))
1269            } else {
1270                None
1271            };
1272            let unit = match rng.u8(0..3) {
1273                0 => TimeUnit::Second,
1274                1 => TimeUnit::Millisecond,
1275                2 => TimeUnit::Microsecond,
1276                _ => TimeUnit::Nanosecond,
1277            };
1278
1279            let schema = Schema::new(vec![Field::new(
1280                "ts",
1281                DataType::Timestamp(unit, None),
1282                false,
1283            )]);
1284            let schema = Arc::new(schema);
1285
1286            let mut input_ranged_data = vec![];
1287            let mut output_ranges = vec![];
1288            let mut output_data = vec![];
1289            // generate each input `PartitionRange`
1290            for part_id in 0..rng.usize(0..part_cnt_bound) {
1291                // generate each `PartitionRange`'s timestamp range
1292                let (start, end) = if descending {
1293                    // Use 1..=range_offset_bound to ensure strictly decreasing end values
1294                    let end = bound_val
1295                        .map(
1296                            |i| i
1297                            .checked_sub(rng.i64(1..=range_offset_bound))
1298                            .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
1299                        )
1300                        .unwrap_or_else(|| rng.i64(-100000000..100000000));
1301                    bound_val = Some(end);
1302                    let start = end - rng.i64(1..range_size_bound);
1303                    let start = Timestamp::new(start, unit.into());
1304                    let end = Timestamp::new(end, unit.into());
1305                    (start, end)
1306                } else {
1307                    // Use 1..=range_offset_bound to ensure strictly increasing start values
1308                    let start = bound_val
1309                        .map(|i| i + rng.i64(1..=range_offset_bound))
1310                        .unwrap_or_else(|| rng.i64(..));
1311                    bound_val = Some(start);
1312                    let end = start + rng.i64(1..range_size_bound);
1313                    let start = Timestamp::new(start, unit.into());
1314                    let end = Timestamp::new(end, unit.into());
1315                    (start, end)
1316                };
1317                assert!(start < end);
1318
1319                let mut per_part_sort_data = vec![];
1320                let mut batches = vec![];
1321                for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
1322                    let cnt = rng.usize(0..batch_size_bound) + 1;
1323                    let iter = 0..rng.usize(0..cnt);
1324                    let mut data_gen = iter
1325                        .map(|_| rng.i64(start.value()..end.value()))
1326                        .collect_vec();
1327                    if data_gen.is_empty() {
1328                        // current batch is empty, skip
1329                        continue;
1330                    }
1331                    // mito always sort on ASC order
1332                    data_gen.sort();
1333                    per_part_sort_data.extend(data_gen.clone());
1334                    let arr = new_ts_array(unit, data_gen.clone());
1335                    let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
1336                    batches.push(batch);
1337                }
1338
1339                let range = PartitionRange {
1340                    start,
1341                    end,
1342                    num_rows: batches.iter().map(|b| b.num_rows()).sum(),
1343                    identifier: part_id,
1344                };
1345                input_ranged_data.push((range, batches));
1346
1347                output_ranges.push(range);
1348                if per_part_sort_data.is_empty() {
1349                    continue;
1350                }
1351                output_data.extend_from_slice(&per_part_sort_data);
1352            }
1353
1354            // adjust output data with adjacent PartitionRanges
1355            let mut output_data_iter = output_data.iter().peekable();
1356            let mut output_data = vec![];
1357            for range in output_ranges.clone() {
1358                let mut cur_data = vec![];
1359                while let Some(val) = output_data_iter.peek() {
1360                    if **val < range.start.value() || **val >= range.end.value() {
1361                        break;
1362                    }
1363                    cur_data.push(*output_data_iter.next().unwrap());
1364                }
1365
1366                if cur_data.is_empty() {
1367                    continue;
1368                }
1369
1370                if descending {
1371                    cur_data.sort_by(|a, b| b.cmp(a));
1372                } else {
1373                    cur_data.sort();
1374                }
1375                output_data.push(cur_data);
1376            }
1377
1378            let expected_output = if let Some(limit) = limit {
1379                let mut accumulated = Vec::new();
1380                let mut seen = 0usize;
1381                for mut range_values in output_data {
1382                    seen += range_values.len();
1383                    accumulated.append(&mut range_values);
1384                    if seen >= limit {
1385                        break;
1386                    }
1387                }
1388
1389                if accumulated.is_empty() {
1390                    None
1391                } else {
1392                    if descending {
1393                        accumulated.sort_by(|a, b| b.cmp(a));
1394                    } else {
1395                        accumulated.sort();
1396                    }
1397                    accumulated.truncate(limit.min(accumulated.len()));
1398
1399                    Some(
1400                        DfRecordBatch::try_new(
1401                            schema.clone(),
1402                            vec![new_ts_array(unit, accumulated)],
1403                        )
1404                        .unwrap(),
1405                    )
1406                }
1407            } else {
1408                let batches = output_data
1409                    .into_iter()
1410                    .map(|a| {
1411                        DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
1412                    })
1413                    .collect_vec();
1414                if batches.is_empty() {
1415                    None
1416                } else {
1417                    Some(concat_batches(&schema, &batches).unwrap())
1418                }
1419            };
1420
1421            test_cases.push((
1422                case_id,
1423                unit,
1424                input_ranged_data,
1425                schema,
1426                opt,
1427                limit,
1428                expected_output,
1429            ));
1430        }
1431
1432        for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
1433            run_test(
1434                case_id,
1435                input_ranged_data,
1436                schema,
1437                opt,
1438                limit,
1439                expected_output,
1440                None,
1441            )
1442            .await;
1443        }
1444    }
1445
    /// Table-driven cases for `PartSortExec`.
    ///
    /// Each case is `(time unit, [((start, end), batches)], descending, limit,
    /// expected output batches)`; batches are timestamp values for the single
    /// `ts` column.
    #[tokio::test]
    async fn simple_cases() {
        let testcases = vec![
            (
                TimeUnit::Millisecond,
                vec![
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
                ],
                false,
                None,
                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
            ),
            // Case 1: Descending sort with overlapping ranges that have the same primary end (end=10).
            // Ranges [5,10) and [0,10) are grouped together, so their data is merged before sorting.
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![vec![17, 18, 19]]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![],
            ),
            // Case 5: Data from one batch spans multiple ranges. Ranges with same end are grouped.
            // Ranges: [15,20) end=20, [10,15) end=15, [5,10) end=10, [0,10) end=10
            // Groups: {[15,20)}, {[10,15)}, {[5,10), [0,10)}
            // The last two ranges are merged because they share end=10.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![
                    vec![19, 17, 15],
                    vec![12, 11, 10],
                    vec![9, 8, 7, 6, 5, 4, 3, 2, 1],
                ],
            ),
            // Case 6: same data as case 5 but with limit 2 — only the top two
            // values are expected.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                Some(2),
                vec![vec![19, 17]],
            ),
        ];

        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
            testcases.into_iter().enumerate()
        {
            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);
            let opt = SortOptions {
                descending,
                ..Default::default()
            };

            let input_ranged_data = input_ranged_data
                .into_iter()
                .map(|(range, data)| {
                    // NOTE(review): `identifier` is the test-case index, shared
                    // by every range in a case — presumably fine here since the
                    // identifier is not validated; confirm if that changes.
                    let part = PartitionRange {
                        start: Timestamp::new(range.0, unit.into()),
                        end: Timestamp::new(range.1, unit.into()),
                        num_rows: data.iter().map(|b| b.len()).sum(),
                        identifier,
                    };

                    let batches = data
                        .into_iter()
                        .map(|b| {
                            let arr = new_ts_array(unit, b);
                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
                        })
                        .collect_vec();
                    (part, batches)
                })
                .collect_vec();

            let expected_output = expected_output
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .collect_vec();
            let expected_output = if expected_output.is_empty() {
                None
            } else {
                Some(concat_batches(&schema, &expected_output).unwrap())
            };

            run_test(
                identifier,
                input_ranged_data,
                schema.clone(),
                opt,
                limit,
                expected_output,
                None,
            )
            .await;
        }
    }
1604
    /// Drives `PartSortExec` over `input_ranged_data` and compares the
    /// concatenated output against `expected_output`.
    ///
    /// * `opt` — sort options applied to the single `ts` column.
    /// * `limit` — optional top-k limit; when set, the stream must emit at
    ///   most one batch.
    /// * `expected_polled_rows` — when set, asserts how many rows were pulled
    ///   from the mock input (verifies early termination).
    ///
    /// Panics with a JSON dump of actual vs. expected on mismatch.
    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Option<DfRecordBatch>,
        expected_polled_rows: Option<usize>,
    ) {
        // Sanity-check the fixture itself: expected output may not exceed the limit.
        if let (Some(limit), Some(rb)) = (limit, &expected_output) {
            assert!(
                rb.num_rows() <= limit,
                "Expect row count in expected output({}) <= limit({})",
                rb.num_rows(),
                limit
            );
        }

        let mut data_partition = Vec::with_capacity(input_ranged_data.len());
        let mut ranges = Vec::with_capacity(input_ranged_data.len());
        for (part_range, batches) in input_ranged_data {
            data_partition.push(batches);
            ranges.push(part_range);
        }

        let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));

        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            mock_input.clone(),
        )
        .unwrap();

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        if limit.is_some() {
            assert!(
                real_output.len() <= 1,
                "case_{case_id} expects a single output batch when limit is set, got {}",
                real_output.len()
            );
        }

        let actual_output = if real_output.is_empty() {
            None
        } else {
            Some(concat_batches(&schema, &real_output).unwrap())
        };

        if let Some(expected_polled_rows) = expected_polled_rows {
            let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
            assert_eq!(input_pulled_rows, expected_polled_rows);
        }

        // Compare outcomes; on mismatch, render both sides as JSON for readability.
        match (actual_output, expected_output) {
            (None, None) => {}
            (Some(actual), Some(expected)) => {
                if actual != expected {
                    let mut actual_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut actual_json);
                    writer.write(&actual).unwrap();
                    writer.finish().unwrap();

                    let mut expected_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut expected_json);
                    writer.write(&expected).unwrap();
                    writer.finish().unwrap();

                    panic!(
                        "case_{} failed (limit {limit:?}), opt: {:?},\nreal_output: {}\nexpected: {}",
                        case_id,
                        opt,
                        String::from_utf8_lossy(&actual_json),
                        String::from_utf8_lossy(&expected_json),
                    );
                }
            }
            (None, Some(expected)) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output is empty, expected {} rows",
                case_id,
                opt,
                expected.num_rows()
            ),
            (Some(actual), None) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output has {} rows, expected empty",
                case_id,
                opt,
                actual.num_rows()
            ),
        }
    }
1703
1704    /// Test that verifies the limit is correctly applied per partition when
1705    /// multiple batches are received for the same partition.
1706    #[tokio::test]
1707    async fn test_limit_with_multiple_batches_per_partition() {
1708        let unit = TimeUnit::Millisecond;
1709        let schema = Arc::new(Schema::new(vec![Field::new(
1710            "ts",
1711            DataType::Timestamp(unit, None),
1712            false,
1713        )]));
1714
1715        // Test case: Multiple batches in a single partition with limit=3
1716        // Input: 3 batches with [1,2,3], [4,5,6], [7,8,9] all in partition (0,10)
1717        // Expected: Only top 3 values [9,8,7] for descending sort
1718        let input_ranged_data = vec![(
1719            PartitionRange {
1720                start: Timestamp::new(0, unit.into()),
1721                end: Timestamp::new(10, unit.into()),
1722                num_rows: 9,
1723                identifier: 0,
1724            },
1725            vec![
1726                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1727                    .unwrap(),
1728                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1729                    .unwrap(),
1730                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1731                    .unwrap(),
1732            ],
1733        )];
1734
1735        let expected_output = Some(
1736            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
1737                .unwrap(),
1738        );
1739
1740        run_test(
1741            1000,
1742            input_ranged_data,
1743            schema.clone(),
1744            SortOptions {
1745                descending: true,
1746                ..Default::default()
1747            },
1748            Some(3),
1749            expected_output,
1750            None,
1751        )
1752        .await;
1753
1754        // Test case: Multiple batches across multiple partitions with limit=2
1755        // Partition 0: batches [10,11,12], [13,14,15] -> top 2 descending = [15,14]
1756        // Partition 1: batches [1,2,3], [4,5] -> top 2 descending = [5,4]
1757        let input_ranged_data = vec![
1758            (
1759                PartitionRange {
1760                    start: Timestamp::new(10, unit.into()),
1761                    end: Timestamp::new(20, unit.into()),
1762                    num_rows: 6,
1763                    identifier: 0,
1764                },
1765                vec![
1766                    DfRecordBatch::try_new(
1767                        schema.clone(),
1768                        vec![new_ts_array(unit, vec![10, 11, 12])],
1769                    )
1770                    .unwrap(),
1771                    DfRecordBatch::try_new(
1772                        schema.clone(),
1773                        vec![new_ts_array(unit, vec![13, 14, 15])],
1774                    )
1775                    .unwrap(),
1776                ],
1777            ),
1778            (
1779                PartitionRange {
1780                    start: Timestamp::new(0, unit.into()),
1781                    end: Timestamp::new(10, unit.into()),
1782                    num_rows: 5,
1783                    identifier: 1,
1784                },
1785                vec![
1786                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1787                        .unwrap(),
1788                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
1789                        .unwrap(),
1790                ],
1791            ),
1792        ];
1793
1794        let expected_output = Some(
1795            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
1796        );
1797
1798        run_test(
1799            1001,
1800            input_ranged_data,
1801            schema.clone(),
1802            SortOptions {
1803                descending: true,
1804                ..Default::default()
1805            },
1806            Some(2),
1807            expected_output,
1808            None,
1809        )
1810        .await;
1811
1812        // Test case: Ascending sort with limit
1813        // Partition: batches [7,8,9], [4,5,6], [1,2,3] -> top 2 ascending = [1,2]
1814        let input_ranged_data = vec![(
1815            PartitionRange {
1816                start: Timestamp::new(0, unit.into()),
1817                end: Timestamp::new(10, unit.into()),
1818                num_rows: 9,
1819                identifier: 0,
1820            },
1821            vec![
1822                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
1823                    .unwrap(),
1824                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
1825                    .unwrap(),
1826                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
1827                    .unwrap(),
1828            ],
1829        )];
1830
1831        let expected_output = Some(
1832            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
1833        );
1834
1835        run_test(
1836            1002,
1837            input_ranged_data,
1838            schema.clone(),
1839            SortOptions {
1840                descending: false,
1841                ..Default::default()
1842            },
1843            Some(2),
1844            expected_output,
1845            None,
1846        )
1847        .await;
1848    }
1849
    #[test]
    fn test_topk_buffer_is_bounded_and_updates_dynamic_filter() {
        // Drives `PartSortStream` directly (bypassing `execute`) to check two
        // invariants of the top-k path with limit=3:
        //   1. the internal buffer never holds more than `limit` rows, and
        //   2. the dynamic filter is refreshed as better rows arrive.
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));
        let sort_data_type = DataType::Timestamp(unit, None);
        let partition_range = PartitionRange {
            start: Timestamp::new(0, unit.into()),
            end: Timestamp::new(10, unit.into()),
            num_rows: 9,
            identifier: 0,
        };
        // The mock input itself is empty; batches are pushed manually below.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: true,
                    ..Default::default()
                },
            },
            Some(3),
            vec![vec![partition_range]],
            mock_input.clone(),
        )
        .unwrap();
        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(3),
            input_stream,
            vec![partition_range],
            0,
        )
        .unwrap();

        // Push 9 rows in three batches; after every push the buffer must be
        // truncated to the top 3 (descending) rows seen so far.
        for batch in [vec![1, 2, 3], vec![4, 5, 6], vec![0, 7, 8]] {
            stream
                .push_buffer(
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, batch)])
                        .unwrap(),
                    &sort_data_type,
                )
                .unwrap();
            assert_eq!(stream.buffer.num_rows(), 3);
        }

        // Top-3 of all pushed values is [8, 7, 6]. The dynamic filter should
        // now admit only values that could still enter the top-k: of the
        // probe [5, 6, 7] only 7 passes.
        let dynamic_filter = stream.dynamic_filter.as_ref().unwrap().clone();
        let probe = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7])])
            .unwrap();
        let predicate = dynamic_filter.current().unwrap();
        let result = predicate
            .evaluate(&probe)
            .unwrap()
            .into_array(probe.num_rows())
            .unwrap();
        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(result, &BooleanArray::from(vec![false, false, true]));

        // Sorting the buffer yields the final top-3 in descending order.
        let expected =
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![8, 7, 6])])
                .unwrap();
        assert_eq!(stream.sort_buffer().unwrap(), expected);
    }
1920
    #[test]
    fn test_topk_limit_zero_clears_buffer_without_threshold() {
        // With limit=0 nothing can ever be emitted: pushing a batch must leave
        // the top-k buffer empty and must not install a dynamic-filter
        // threshold.
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));
        let sort_data_type = DataType::Timestamp(unit, None);
        let partition_range = PartitionRange {
            start: Timestamp::new(0, unit.into()),
            end: Timestamp::new(10, unit.into()),
            num_rows: 3,
            identifier: 0,
        };
        // The mock input itself is empty; the batch is pushed manually below.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: true,
                    ..Default::default()
                },
            },
            Some(0),
            vec![vec![partition_range]],
            mock_input.clone(),
        )
        .unwrap();
        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(0),
            input_stream,
            vec![partition_range],
            0,
        )
        .unwrap();

        stream
            .push_buffer(
                DfRecordBatch::try_new(schema, vec![new_ts_array(unit, vec![1, 2, 3])]).unwrap(),
                &sort_data_type,
            )
            .unwrap();

        // All pushed rows were discarded and no threshold was recorded.
        assert_eq!(stream.buffer.num_rows(), 0);
        assert_eq!(stream.dynamic_filter_threshold, None);
    }
1973
1974    /// Test that verifies early termination behavior.
1975    /// Once we've produced limit * num_partitions rows, we should stop
1976    /// pulling from input stream.
1977    #[tokio::test]
1978    async fn test_early_termination() {
1979        let unit = TimeUnit::Millisecond;
1980        let schema = Arc::new(Schema::new(vec![Field::new(
1981            "ts",
1982            DataType::Timestamp(unit, None),
1983            false,
1984        )]));
1985
1986        // Create 3 partitions, each with more data than the limit
1987        // limit=2 per partition, so total expected output = 6 rows
1988        // After producing 6 rows, early termination should kick in
1989        // For descending sort, ranges must be ordered by (end DESC, start DESC)
1990        let input_ranged_data = vec![
1991            (
1992                PartitionRange {
1993                    start: Timestamp::new(20, unit.into()),
1994                    end: Timestamp::new(30, unit.into()),
1995                    num_rows: 10,
1996                    identifier: 2,
1997                },
1998                vec![
1999                    DfRecordBatch::try_new(
2000                        schema.clone(),
2001                        vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
2002                    )
2003                    .unwrap(),
2004                    DfRecordBatch::try_new(
2005                        schema.clone(),
2006                        vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
2007                    )
2008                    .unwrap(),
2009                ],
2010            ),
2011            (
2012                PartitionRange {
2013                    start: Timestamp::new(10, unit.into()),
2014                    end: Timestamp::new(20, unit.into()),
2015                    num_rows: 10,
2016                    identifier: 1,
2017                },
2018                vec![
2019                    DfRecordBatch::try_new(
2020                        schema.clone(),
2021                        vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
2022                    )
2023                    .unwrap(),
2024                    DfRecordBatch::try_new(
2025                        schema.clone(),
2026                        vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
2027                    )
2028                    .unwrap(),
2029                ],
2030            ),
2031            (
2032                PartitionRange {
2033                    start: Timestamp::new(0, unit.into()),
2034                    end: Timestamp::new(10, unit.into()),
2035                    num_rows: 10,
2036                    identifier: 0,
2037                },
2038                vec![
2039                    DfRecordBatch::try_new(
2040                        schema.clone(),
2041                        vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
2042                    )
2043                    .unwrap(),
2044                    DfRecordBatch::try_new(
2045                        schema.clone(),
2046                        vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
2047                    )
2048                    .unwrap(),
2049                ],
2050            ),
2051        ];
2052
2053        // PartSort won't reorder `PartitionRange` (it assumes it's already ordered), so it will not read other partitions.
2054        // This case is just to verify that early termination works as expected.
2055        // First partition [20, 30) produces top 2 values: 29, 28
2056        let expected_output = Some(
2057            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![29, 28])]).unwrap(),
2058        );
2059
2060        run_test(
2061            1003,
2062            input_ranged_data,
2063            schema.clone(),
2064            SortOptions {
2065                descending: true,
2066                ..Default::default()
2067            },
2068            Some(2),
2069            expected_output,
2070            Some(10),
2071        )
2072        .await;
2073    }
2074
2075    /// Example:
2076    /// - Range [70, 100) has data [80, 90, 95]
2077    /// - Range [50, 100) has data [55, 65, 75, 85, 95]
2078    #[tokio::test]
2079    async fn test_primary_end_grouping_with_limit() {
2080        let unit = TimeUnit::Millisecond;
2081        let schema = Arc::new(Schema::new(vec![Field::new(
2082            "ts",
2083            DataType::Timestamp(unit, None),
2084            false,
2085        )]));
2086
2087        // Two ranges with the same end (100) - they should be grouped together
2088        // For descending, ranges are ordered by (end DESC, start DESC)
2089        // So [70, 100) comes before [50, 100) (70 > 50)
2090        let input_ranged_data = vec![
2091            (
2092                PartitionRange {
2093                    start: Timestamp::new(70, unit.into()),
2094                    end: Timestamp::new(100, unit.into()),
2095                    num_rows: 3,
2096                    identifier: 0,
2097                },
2098                vec![
2099                    DfRecordBatch::try_new(
2100                        schema.clone(),
2101                        vec![new_ts_array(unit, vec![80, 90, 95])],
2102                    )
2103                    .unwrap(),
2104                ],
2105            ),
2106            (
2107                PartitionRange {
2108                    start: Timestamp::new(50, unit.into()),
2109                    end: Timestamp::new(100, unit.into()),
2110                    num_rows: 5,
2111                    identifier: 1,
2112                },
2113                vec![
2114                    DfRecordBatch::try_new(
2115                        schema.clone(),
2116                        vec![new_ts_array(unit, vec![55, 65, 75, 85, 95])],
2117                    )
2118                    .unwrap(),
2119                ],
2120            ),
2121        ];
2122
2123        // With limit=4, descending: top 4 values from combined data
2124        // Combined: [80, 90, 95, 55, 65, 75, 85, 95] -> sorted desc: [95, 95, 90, 85, 80, 75, 65, 55]
2125        // Top 4: [95, 95, 90, 85]
2126        let expected_output = Some(
2127            DfRecordBatch::try_new(
2128                schema.clone(),
2129                vec![new_ts_array(unit, vec![95, 95, 90, 85])],
2130            )
2131            .unwrap(),
2132        );
2133
2134        run_test(
2135            2000,
2136            input_ranged_data,
2137            schema.clone(),
2138            SortOptions {
2139                descending: true,
2140                ..Default::default()
2141            },
2142            Some(4),
2143            expected_output,
2144            None,
2145        )
2146        .await;
2147    }
2148
2149    /// Test case with three ranges demonstrating the "keep pulling" behavior.
2150    /// After processing ranges with end=100, the smallest value in top-k might still
2151    /// be reachable by the next group.
2152    ///
2153    /// Ranges: [70, 100), [50, 100), [40, 95)
2154    /// With descending sort and limit=4:
2155    /// - Group 1 (end=100): [70, 100) and [50, 100) merged
2156    /// - Group 2 (end=95): [40, 95)
2157    /// After group 1, smallest in top-4 is 85. Range [40, 95) could have values >= 85,
2158    /// so we continue to group 2.
2159    #[tokio::test]
2160    async fn test_three_ranges_keep_pulling() {
2161        let unit = TimeUnit::Millisecond;
2162        let schema = Arc::new(Schema::new(vec![Field::new(
2163            "ts",
2164            DataType::Timestamp(unit, None),
2165            false,
2166        )]));
2167
2168        // Three ranges, two with same end (100), one with different end (95)
2169        let input_ranged_data = vec![
2170            (
2171                PartitionRange {
2172                    start: Timestamp::new(70, unit.into()),
2173                    end: Timestamp::new(100, unit.into()),
2174                    num_rows: 3,
2175                    identifier: 0,
2176                },
2177                vec![
2178                    DfRecordBatch::try_new(
2179                        schema.clone(),
2180                        vec![new_ts_array(unit, vec![80, 90, 95])],
2181                    )
2182                    .unwrap(),
2183                ],
2184            ),
2185            (
2186                PartitionRange {
2187                    start: Timestamp::new(50, unit.into()),
2188                    end: Timestamp::new(100, unit.into()),
2189                    num_rows: 3,
2190                    identifier: 1,
2191                },
2192                vec![
2193                    DfRecordBatch::try_new(
2194                        schema.clone(),
2195                        vec![new_ts_array(unit, vec![55, 75, 85])],
2196                    )
2197                    .unwrap(),
2198                ],
2199            ),
2200            (
2201                PartitionRange {
2202                    start: Timestamp::new(40, unit.into()),
2203                    end: Timestamp::new(95, unit.into()),
2204                    num_rows: 3,
2205                    identifier: 2,
2206                },
2207                vec![
2208                    DfRecordBatch::try_new(
2209                        schema.clone(),
2210                        vec![new_ts_array(unit, vec![45, 65, 94])],
2211                    )
2212                    .unwrap(),
2213                ],
2214            ),
2215        ];
2216
2217        // All data: [80, 90, 95, 55, 75, 85, 45, 65, 94]
2218        // Sorted descending: [95, 94, 90, 85, 80, 75, 65, 55, 45]
2219        // With limit=4: should be top 4 largest values across all ranges: [95, 94, 90, 85]
2220        let expected_output = Some(
2221            DfRecordBatch::try_new(
2222                schema.clone(),
2223                vec![new_ts_array(unit, vec![95, 94, 90, 85])],
2224            )
2225            .unwrap(),
2226        );
2227
2228        run_test(
2229            2001,
2230            input_ranged_data,
2231            schema.clone(),
2232            SortOptions {
2233                descending: true,
2234                ..Default::default()
2235            },
2236            Some(4),
2237            expected_output,
2238            None,
2239        )
2240        .await;
2241    }
2242
2243    /// Test early termination based on threshold comparison with next group.
2244    /// When the threshold (smallest value for descending) is >= next group's primary end,
2245    /// we can stop early because the next group cannot have better values.
2246    #[tokio::test]
2247    async fn test_threshold_based_early_termination() {
2248        let unit = TimeUnit::Millisecond;
2249        let schema = Arc::new(Schema::new(vec![Field::new(
2250            "ts",
2251            DataType::Timestamp(unit, None),
2252            false,
2253        )]));
2254
2255        // Group 1 (end=100) has 6 rows, TopK will keep top 4
2256        // Group 2 (end=90) has 3 rows - should NOT be processed because
2257        // threshold (96) >= next_primary_end (90)
2258        let input_ranged_data = vec![
2259            (
2260                PartitionRange {
2261                    start: Timestamp::new(70, unit.into()),
2262                    end: Timestamp::new(100, unit.into()),
2263                    num_rows: 6,
2264                    identifier: 0,
2265                },
2266                vec![
2267                    DfRecordBatch::try_new(
2268                        schema.clone(),
2269                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2270                    )
2271                    .unwrap(),
2272                ],
2273            ),
2274            (
2275                PartitionRange {
2276                    start: Timestamp::new(50, unit.into()),
2277                    end: Timestamp::new(90, unit.into()),
2278                    num_rows: 3,
2279                    identifier: 1,
2280                },
2281                vec![
2282                    DfRecordBatch::try_new(
2283                        schema.clone(),
2284                        vec![new_ts_array(unit, vec![85, 86, 87])],
2285                    )
2286                    .unwrap(),
2287                ],
2288            ),
2289        ];
2290
2291        // With limit=4, descending: top 4 from group 1 are [99, 98, 97, 96]
2292        // Threshold is 96, next group's primary_end is 90
2293        // Since 96 >= 90, we stop after group 1
2294        let expected_output = Some(
2295            DfRecordBatch::try_new(
2296                schema.clone(),
2297                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2298            )
2299            .unwrap(),
2300        );
2301
2302        run_test(
2303            2002,
2304            input_ranged_data,
2305            schema.clone(),
2306            SortOptions {
2307                descending: true,
2308                ..Default::default()
2309            },
2310            Some(4),
2311            expected_output,
2312            Some(9), // Pull both batches since all rows fall within the first range
2313        )
2314        .await;
2315    }
2316
2317    /// Test that we continue to next group when threshold is within next group's range.
2318    /// Even after fulfilling limit, if threshold < next_primary_end (descending),
2319    /// we would need to continue... but limit exhaustion stops us first.
2320    #[tokio::test]
2321    async fn test_continue_when_threshold_in_next_group_range() {
2322        let unit = TimeUnit::Millisecond;
2323        let schema = Arc::new(Schema::new(vec![Field::new(
2324            "ts",
2325            DataType::Timestamp(unit, None),
2326            false,
2327        )]));
2328
2329        // Group 1 (end=100) has 6 rows, TopK will keep top 4
2330        // Group 2 (end=98) has 3 rows - threshold (96) < 98, so next group
2331        // could theoretically have better values. Continue reading.
2332        let input_ranged_data = vec![
2333            (
2334                PartitionRange {
2335                    start: Timestamp::new(90, unit.into()),
2336                    end: Timestamp::new(100, unit.into()),
2337                    num_rows: 6,
2338                    identifier: 0,
2339                },
2340                vec![
2341                    DfRecordBatch::try_new(
2342                        schema.clone(),
2343                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2344                    )
2345                    .unwrap(),
2346                ],
2347            ),
2348            (
2349                PartitionRange {
2350                    start: Timestamp::new(50, unit.into()),
2351                    end: Timestamp::new(98, unit.into()),
2352                    num_rows: 3,
2353                    identifier: 1,
2354                },
2355                vec![
2356                    // Values must be < 70 (outside group 1's range) to avoid ambiguity
2357                    DfRecordBatch::try_new(
2358                        schema.clone(),
2359                        vec![new_ts_array(unit, vec![55, 60, 65])],
2360                    )
2361                    .unwrap(),
2362                ],
2363            ),
2364        ];
2365
2366        // With limit=4, we get [99, 98, 97, 96] from group 1
2367        // Threshold is 96, next group's primary_end is 98
2368        // 96 < 98, so threshold check says "could continue"
2369        // But limit is exhausted (0), so we stop anyway
2370        let expected_output = Some(
2371            DfRecordBatch::try_new(
2372                schema.clone(),
2373                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2374            )
2375            .unwrap(),
2376        );
2377
2378        // Note: We pull 9 rows (both batches) because we need to read batch 2
2379        // to detect the group boundary, even though we stop after outputting group 1.
2380        run_test(
2381            2003,
2382            input_ranged_data,
2383            schema.clone(),
2384            SortOptions {
2385                descending: true,
2386                ..Default::default()
2387            },
2388            Some(4),
2389            expected_output,
2390            Some(9), // Pull both batches to detect boundary
2391        )
2392        .await;
2393    }
2394
    /// Test ascending sort with threshold-based early termination.
    #[tokio::test]
    async fn test_ascending_threshold_early_termination() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC)
        // Group 1 (start=10) has 6 rows
        // Group 2 (start=20) has 3 rows - should NOT be processed because
        // threshold (13) < next_primary_end (20)
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(50, unit.into()),
                    num_rows: 6,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![10, 11, 12, 13, 14, 15])],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(20, unit.into()),
                    end: Timestamp::new(60, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![25, 30, 35])],
                    )
                    .unwrap(),
                ],
            ),
            // still read this batch to detect group boundary(?)
            (
                PartitionRange {
                    start: Timestamp::new(60, unit.into()),
                    end: Timestamp::new(70, unit.into()),
                    num_rows: 2,
                    // NOTE(review): `identifier: 1` duplicates the previous range's
                    // id — looks like a copy-paste; confirm identifiers need not be
                    // unique for this plan.
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![60, 61])])
                        .unwrap(),
                ],
            ),
            // after boundary detected, this following one should not be read
            (
                PartitionRange {
                    start: Timestamp::new(61, unit.into()),
                    end: Timestamp::new(70, unit.into()),
                    num_rows: 2,
                    // NOTE(review): also `identifier: 1`; same copy-paste concern
                    // as the range above.
                    identifier: 1,
                },
                vec![
                    // NOTE(review): values 71 and 72 lie outside the declared range
                    // [61, 70) — harmless here since this batch is never read, but
                    // worth confirming that is intentional.
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![71, 72])])
                        .unwrap(),
                ],
            ),
        ];

        // With limit=4, ascending: top 4 (smallest) from group 1 are [10, 11, 12, 13]
        // Threshold is 13 (largest in top-k), next group's primary_end is 20
        // Since 13 < 20, we stop after group 1 (no value in group 2 can be < 13)
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_ts_array(unit, vec![10, 11, 12, 13])],
            )
            .unwrap(),
        );

        run_test(
            2004,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: false,
                ..Default::default()
            },
            Some(4),
            expected_output,
            Some(11), // Pull the first three batches (6 + 3 + 2 rows) to detect the boundary
        )
        .await;
    }
2493
2494    #[tokio::test]
2495    async fn test_ascending_threshold_early_termination_case_two() {
2496        let unit = TimeUnit::Millisecond;
2497        let schema = Arc::new(Schema::new(vec![Field::new(
2498            "ts",
2499            DataType::Timestamp(unit, None),
2500            false,
2501        )]));
2502
2503        // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC)
2504        // Group 1 (start=0) has 4 rows, Group 2 (start=4) has 1 row, Group 3 (start=5) has 4 rows
2505        // After reading all data: [9,10,11,12, 21, 5,6,7,8]
2506        // Sorted ascending: [5,6,7,8, 9,10,11,12, 21]
2507        // With limit=4, output should be smallest 4: [5,6,7,8]
2508        // Algorithm continues reading until start=42 > threshold=8, confirming no smaller values exist
2509        let input_ranged_data = vec![
2510            (
2511                PartitionRange {
2512                    start: Timestamp::new(0, unit.into()),
2513                    end: Timestamp::new(20, unit.into()),
2514                    num_rows: 4,
2515                    identifier: 0,
2516                },
2517                vec![
2518                    DfRecordBatch::try_new(
2519                        schema.clone(),
2520                        vec![new_ts_array(unit, vec![9, 10, 11, 12])],
2521                    )
2522                    .unwrap(),
2523                ],
2524            ),
2525            (
2526                PartitionRange {
2527                    start: Timestamp::new(4, unit.into()),
2528                    end: Timestamp::new(25, unit.into()),
2529                    num_rows: 1,
2530                    identifier: 1,
2531                },
2532                vec![
2533                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21])])
2534                        .unwrap(),
2535                ],
2536            ),
2537            (
2538                PartitionRange {
2539                    start: Timestamp::new(5, unit.into()),
2540                    end: Timestamp::new(25, unit.into()),
2541                    num_rows: 4,
2542                    identifier: 1,
2543                },
2544                vec![
2545                    DfRecordBatch::try_new(
2546                        schema.clone(),
2547                        vec![new_ts_array(unit, vec![5, 6, 7, 8])],
2548                    )
2549                    .unwrap(),
2550                ],
2551            ),
2552            // This still will be read to detect boundary, but should not contribute to output
2553            (
2554                PartitionRange {
2555                    start: Timestamp::new(42, unit.into()),
2556                    end: Timestamp::new(52, unit.into()),
2557                    num_rows: 2,
2558                    identifier: 1,
2559                },
2560                vec![
2561                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![42, 51])])
2562                        .unwrap(),
2563                ],
2564            ),
2565            // This following one should not be read after boundary detected
2566            (
2567                PartitionRange {
2568                    start: Timestamp::new(48, unit.into()),
2569                    end: Timestamp::new(53, unit.into()),
2570                    num_rows: 2,
2571                    identifier: 1,
2572                },
2573                vec![
2574                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![48, 51])])
2575                        .unwrap(),
2576                ],
2577            ),
2578        ];
2579
2580        // With limit=4, ascending: after processing all ranges, smallest 4 are [5, 6, 7, 8]
2581        // Threshold is 8 (4th smallest value), algorithm reads until start=42 > threshold=8
2582        let expected_output = Some(
2583            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7, 8])])
2584                .unwrap(),
2585        );
2586
2587        run_test(
2588            2005,
2589            input_ranged_data,
2590            schema.clone(),
2591            SortOptions {
2592                descending: false,
2593                ..Default::default()
2594            },
2595            Some(4),
2596            expected_output,
2597            Some(11), // Read first 4 ranges to confirm threshold boundary
2598        )
2599        .await;
2600    }
2601
2602    /// Test early stop behavior with null values in sort column.
2603    /// Verifies that nulls are handled correctly based on nulls_first option.
2604    #[tokio::test]
2605    async fn test_early_stop_with_nulls() {
2606        let unit = TimeUnit::Millisecond;
2607        let schema = Arc::new(Schema::new(vec![Field::new(
2608            "ts",
2609            DataType::Timestamp(unit, None),
2610            true, // nullable
2611        )]));
2612
2613        // Helper function to create nullable timestamp array
2614        let new_nullable_ts_array = |unit: TimeUnit, arr: Vec<Option<i64>>| -> ArrayRef {
2615            match unit {
2616                TimeUnit::Second => Arc::new(TimestampSecondArray::from(arr)) as ArrayRef,
2617                TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(arr)) as ArrayRef,
2618                TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(arr)) as ArrayRef,
2619                TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(arr)) as ArrayRef,
2620            }
2621        };
2622
2623        // Test case 1: nulls_first=true, null values should appear first
2624        // Group 1 (end=100): [null, null, 99, 98, 97] -> with limit=3, top 3 are [null, null, 99]
2625        // Threshold is 99, next group end=90, since 99 >= 90, we should stop early
2626        let input_ranged_data = vec![
2627            (
2628                PartitionRange {
2629                    start: Timestamp::new(70, unit.into()),
2630                    end: Timestamp::new(100, unit.into()),
2631                    num_rows: 5,
2632                    identifier: 0,
2633                },
2634                vec![
2635                    DfRecordBatch::try_new(
2636                        schema.clone(),
2637                        vec![new_nullable_ts_array(
2638                            unit,
2639                            vec![Some(99), Some(98), None, Some(97), None],
2640                        )],
2641                    )
2642                    .unwrap(),
2643                ],
2644            ),
2645            (
2646                PartitionRange {
2647                    start: Timestamp::new(50, unit.into()),
2648                    end: Timestamp::new(90, unit.into()),
2649                    num_rows: 3,
2650                    identifier: 1,
2651                },
2652                vec![
2653                    DfRecordBatch::try_new(
2654                        schema.clone(),
2655                        vec![new_nullable_ts_array(
2656                            unit,
2657                            vec![Some(89), Some(88), Some(87)],
2658                        )],
2659                    )
2660                    .unwrap(),
2661                ],
2662            ),
2663        ];
2664
2665        // With nulls_first=true, nulls sort before all values
2666        // For descending, order is: null, null, 99, 98, 97
2667        // With limit=3, we get: null, null, 99
2668        let expected_output = Some(
2669            DfRecordBatch::try_new(
2670                schema.clone(),
2671                vec![new_nullable_ts_array(unit, vec![None, None, Some(99)])],
2672            )
2673            .unwrap(),
2674        );
2675
2676        run_test(
2677            3000,
2678            input_ranged_data,
2679            schema.clone(),
2680            SortOptions {
2681                descending: true,
2682                nulls_first: true,
2683            },
2684            Some(3),
2685            expected_output,
2686            Some(8), // Must read both batches to detect group boundary
2687        )
2688        .await;
2689
2690        // Test case 2: nulls_last=true, null values should appear last
2691        // Group 1 (end=100): [99, 98, 97, null, null] -> with limit=3, top 3 are [99, 98, 97]
2692        // Threshold is 97, next group end=90, since 97 >= 90, we should stop early
2693        let input_ranged_data = vec![
2694            (
2695                PartitionRange {
2696                    start: Timestamp::new(70, unit.into()),
2697                    end: Timestamp::new(100, unit.into()),
2698                    num_rows: 5,
2699                    identifier: 0,
2700                },
2701                vec![
2702                    DfRecordBatch::try_new(
2703                        schema.clone(),
2704                        vec![new_nullable_ts_array(
2705                            unit,
2706                            vec![Some(99), Some(98), Some(97), None, None],
2707                        )],
2708                    )
2709                    .unwrap(),
2710                ],
2711            ),
2712            (
2713                PartitionRange {
2714                    start: Timestamp::new(50, unit.into()),
2715                    end: Timestamp::new(90, unit.into()),
2716                    num_rows: 3,
2717                    identifier: 1,
2718                },
2719                vec![
2720                    DfRecordBatch::try_new(
2721                        schema.clone(),
2722                        vec![new_nullable_ts_array(
2723                            unit,
2724                            vec![Some(89), Some(88), Some(87)],
2725                        )],
2726                    )
2727                    .unwrap(),
2728                ],
2729            ),
2730        ];
2731
2732        // With nulls_last=false (equivalent to nulls_first=false), values sort before nulls
2733        // For descending, order is: 99, 98, 97, null, null
2734        // With limit=3, we get: 99, 98, 97
2735        let expected_output = Some(
2736            DfRecordBatch::try_new(
2737                schema.clone(),
2738                vec![new_nullable_ts_array(
2739                    unit,
2740                    vec![Some(99), Some(98), Some(97)],
2741                )],
2742            )
2743            .unwrap(),
2744        );
2745
2746        run_test(
2747            3001,
2748            input_ranged_data,
2749            schema.clone(),
2750            SortOptions {
2751                descending: true,
2752                nulls_first: false,
2753            },
2754            Some(3),
2755            expected_output,
2756            Some(8), // Must read both batches to detect group boundary
2757        )
2758        .await;
2759    }
2760
2761    /// Test early stop behavior when there's only one group (no next group).
2762    /// In this case, can_stop_early should return false and we should process all data.
2763    #[tokio::test]
2764    async fn test_early_stop_single_group() {
2765        let unit = TimeUnit::Millisecond;
2766        let schema = Arc::new(Schema::new(vec![Field::new(
2767            "ts",
2768            DataType::Timestamp(unit, None),
2769            false,
2770        )]));
2771
2772        // Only one group (all ranges have the same end), no next group to compare against
2773        let input_ranged_data = vec![
2774            (
2775                PartitionRange {
2776                    start: Timestamp::new(70, unit.into()),
2777                    end: Timestamp::new(100, unit.into()),
2778                    num_rows: 6,
2779                    identifier: 0,
2780                },
2781                vec![
2782                    DfRecordBatch::try_new(
2783                        schema.clone(),
2784                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2785                    )
2786                    .unwrap(),
2787                ],
2788            ),
2789            (
2790                PartitionRange {
2791                    start: Timestamp::new(50, unit.into()),
2792                    end: Timestamp::new(100, unit.into()),
2793                    num_rows: 3,
2794                    identifier: 1,
2795                },
2796                vec![
2797                    DfRecordBatch::try_new(
2798                        schema.clone(),
2799                        vec![new_ts_array(unit, vec![85, 86, 87])],
2800                    )
2801                    .unwrap(),
2802                ],
2803            ),
2804        ];
2805
2806        // Even though we have enough data in first range, we must process all
2807        // because there's no next group to compare threshold against
2808        let expected_output = Some(
2809            DfRecordBatch::try_new(
2810                schema.clone(),
2811                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2812            )
2813            .unwrap(),
2814        );
2815
2816        run_test(
2817            3002,
2818            input_ranged_data,
2819            schema.clone(),
2820            SortOptions {
2821                descending: true,
2822                ..Default::default()
2823            },
2824            Some(4),
2825            expected_output,
2826            Some(9), // Must read all batches since no early stop is possible
2827        )
2828        .await;
2829    }
2830
2831    /// Test early stop behavior when threshold exactly equals next group's boundary.
2832    #[tokio::test]
2833    async fn test_early_stop_exact_boundary_equality() {
2834        let unit = TimeUnit::Millisecond;
2835        let schema = Arc::new(Schema::new(vec![Field::new(
2836            "ts",
2837            DataType::Timestamp(unit, None),
2838            false,
2839        )]));
2840
2841        // Test case 1: Descending sort, threshold == next_group_end
2842        // Group 1 (end=100): data up to 90, threshold = 90, next_group_end = 90
2843        // Since 90 >= 90, we should stop early
2844        let input_ranged_data = vec![
2845            (
2846                PartitionRange {
2847                    start: Timestamp::new(70, unit.into()),
2848                    end: Timestamp::new(100, unit.into()),
2849                    num_rows: 4,
2850                    identifier: 0,
2851                },
2852                vec![
2853                    DfRecordBatch::try_new(
2854                        schema.clone(),
2855                        vec![new_ts_array(unit, vec![92, 91, 90, 89])],
2856                    )
2857                    .unwrap(),
2858                ],
2859            ),
2860            (
2861                PartitionRange {
2862                    start: Timestamp::new(50, unit.into()),
2863                    end: Timestamp::new(90, unit.into()),
2864                    num_rows: 3,
2865                    identifier: 1,
2866                },
2867                vec![
2868                    DfRecordBatch::try_new(
2869                        schema.clone(),
2870                        vec![new_ts_array(unit, vec![88, 87, 86])],
2871                    )
2872                    .unwrap(),
2873                ],
2874            ),
2875        ];
2876
2877        let expected_output = Some(
2878            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![92, 91, 90])])
2879                .unwrap(),
2880        );
2881
2882        run_test(
2883            3003,
2884            input_ranged_data,
2885            schema.clone(),
2886            SortOptions {
2887                descending: true,
2888                ..Default::default()
2889            },
2890            Some(3),
2891            expected_output,
2892            Some(7), // Must read both batches to detect boundary
2893        )
2894        .await;
2895
2896        // Test case 2: Ascending sort, threshold == next_group_start
2897        // Group 1 (start=10): data from 10, threshold = 20, next_group_start = 20
2898        // Since 20 < 20 is false, we should continue
2899        let input_ranged_data = vec![
2900            (
2901                PartitionRange {
2902                    start: Timestamp::new(10, unit.into()),
2903                    end: Timestamp::new(50, unit.into()),
2904                    num_rows: 4,
2905                    identifier: 0,
2906                },
2907                vec![
2908                    DfRecordBatch::try_new(
2909                        schema.clone(),
2910                        vec![new_ts_array(unit, vec![10, 15, 20, 25])],
2911                    )
2912                    .unwrap(),
2913                ],
2914            ),
2915            (
2916                PartitionRange {
2917                    start: Timestamp::new(20, unit.into()),
2918                    end: Timestamp::new(60, unit.into()),
2919                    num_rows: 3,
2920                    identifier: 1,
2921                },
2922                vec![
2923                    DfRecordBatch::try_new(
2924                        schema.clone(),
2925                        vec![new_ts_array(unit, vec![21, 22, 23])],
2926                    )
2927                    .unwrap(),
2928                ],
2929            ),
2930        ];
2931
2932        let expected_output = Some(
2933            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![10, 15, 20])])
2934                .unwrap(),
2935        );
2936
2937        run_test(
2938            3004,
2939            input_ranged_data,
2940            schema.clone(),
2941            SortOptions {
2942                descending: false,
2943                ..Default::default()
2944            },
2945            Some(3),
2946            expected_output,
2947            Some(7), // Must read both batches since 20 is not < 20
2948        )
2949        .await;
2950    }
2951
2952    /// Test early stop behavior with empty partition groups.
2953    #[tokio::test]
2954    async fn test_early_stop_with_empty_partitions() {
2955        let unit = TimeUnit::Millisecond;
2956        let schema = Arc::new(Schema::new(vec![Field::new(
2957            "ts",
2958            DataType::Timestamp(unit, None),
2959            false,
2960        )]));
2961
2962        // Test case 1: First group is empty, second group has data
2963        let input_ranged_data = vec![
2964            (
2965                PartitionRange {
2966                    start: Timestamp::new(70, unit.into()),
2967                    end: Timestamp::new(100, unit.into()),
2968                    num_rows: 0,
2969                    identifier: 0,
2970                },
2971                vec![
2972                    // Empty batch for first range
2973                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2974                        .unwrap(),
2975                ],
2976            ),
2977            (
2978                PartitionRange {
2979                    start: Timestamp::new(50, unit.into()),
2980                    end: Timestamp::new(100, unit.into()),
2981                    num_rows: 0,
2982                    identifier: 1,
2983                },
2984                vec![
2985                    // Empty batch for second range
2986                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2987                        .unwrap(),
2988                ],
2989            ),
2990            (
2991                PartitionRange {
2992                    start: Timestamp::new(30, unit.into()),
2993                    end: Timestamp::new(80, unit.into()),
2994                    num_rows: 4,
2995                    identifier: 2,
2996                },
2997                vec![
2998                    DfRecordBatch::try_new(
2999                        schema.clone(),
3000                        vec![new_ts_array(unit, vec![74, 75, 76, 77])],
3001                    )
3002                    .unwrap(),
3003                ],
3004            ),
3005            (
3006                PartitionRange {
3007                    start: Timestamp::new(10, unit.into()),
3008                    end: Timestamp::new(60, unit.into()),
3009                    num_rows: 3,
3010                    identifier: 3,
3011                },
3012                vec![
3013                    DfRecordBatch::try_new(
3014                        schema.clone(),
3015                        vec![new_ts_array(unit, vec![58, 59, 60])],
3016                    )
3017                    .unwrap(),
3018                ],
3019            ),
3020        ];
3021
3022        // Group 1 (end=100) is empty, Group 2 (end=80) has data
3023        // Should continue to Group 2 since Group 1 has no data
3024        let expected_output = Some(
3025            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![77, 76])]).unwrap(),
3026        );
3027
3028        run_test(
3029            3005,
3030            input_ranged_data,
3031            schema.clone(),
3032            SortOptions {
3033                descending: true,
3034                ..Default::default()
3035            },
3036            Some(2),
3037            expected_output,
3038            Some(7), // Must read until finding actual data
3039        )
3040        .await;
3041
3042        // Test case 2: Empty partitions between data groups
3043        let input_ranged_data = vec![
3044            (
3045                PartitionRange {
3046                    start: Timestamp::new(70, unit.into()),
3047                    end: Timestamp::new(100, unit.into()),
3048                    num_rows: 4,
3049                    identifier: 0,
3050                },
3051                vec![
3052                    DfRecordBatch::try_new(
3053                        schema.clone(),
3054                        vec![new_ts_array(unit, vec![96, 97, 98, 99])],
3055                    )
3056                    .unwrap(),
3057                ],
3058            ),
3059            (
3060                PartitionRange {
3061                    start: Timestamp::new(50, unit.into()),
3062                    end: Timestamp::new(90, unit.into()),
3063                    num_rows: 0,
3064                    identifier: 1,
3065                },
3066                vec![
3067                    // Empty range - should be skipped
3068                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
3069                        .unwrap(),
3070                ],
3071            ),
3072            (
3073                PartitionRange {
3074                    start: Timestamp::new(30, unit.into()),
3075                    end: Timestamp::new(70, unit.into()),
3076                    num_rows: 0,
3077                    identifier: 2,
3078                },
3079                vec![
3080                    // Another empty range
3081                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
3082                        .unwrap(),
3083                ],
3084            ),
3085            (
3086                PartitionRange {
3087                    start: Timestamp::new(10, unit.into()),
3088                    end: Timestamp::new(50, unit.into()),
3089                    num_rows: 3,
3090                    identifier: 3,
3091                },
3092                vec![
3093                    DfRecordBatch::try_new(
3094                        schema.clone(),
3095                        vec![new_ts_array(unit, vec![48, 49, 50])],
3096                    )
3097                    .unwrap(),
3098                ],
3099            ),
3100        ];
3101
3102        // With limit=2 from group 1: [99, 98], threshold=98, next group end=50
3103        // Since 98 >= 50, we should stop early
3104        let expected_output = Some(
3105            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![99, 98])]).unwrap(),
3106        );
3107
3108        run_test(
3109            3006,
3110            input_ranged_data,
3111            schema.clone(),
3112            SortOptions {
3113                descending: true,
3114                ..Default::default()
3115            },
3116            Some(2),
3117            expected_output,
3118            Some(7), // Must read to detect early stop condition
3119        )
3120        .await;
3121    }
3122}