1use std::any::Any;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use arrow::array::{Array, ArrayRef};
27use arrow::compute::{concat, concat_batches, take_record_batch};
28use arrow_schema::{DataType, SchemaRef, TimeUnit};
29use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
30use common_time::Timestamp;
31use datafusion::common::arrow::compute::sort_to_indices;
32use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
33use datafusion::execution::{RecordBatchStream, TaskContext};
34use datafusion::physical_plan::execution_plan::CardinalityEffect;
35use datafusion::physical_plan::filter_pushdown::{
36 ChildFilterDescription, FilterDescription, FilterPushdownPhase,
37};
38use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
39use datafusion::physical_plan::{
40 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
41};
42use datafusion_common::{DataFusionError, ScalarValue, internal_err};
43use datafusion_expr::Operator;
44use datafusion_physical_expr::expressions::{
45 BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit,
46};
47use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
48use futures::Stream;
49use itertools::Itertools;
50use snafu::location;
51use store_api::region_engine::PartitionRange;
52
53use crate::error::Result;
54use crate::window_sort::{check_partition_range_monotonicity, project_partition_range_for_sort};
55use crate::{array_iter_helper, downcast_ts_array};
56
57fn get_primary_end(range: &PartitionRange, descending: bool) -> Timestamp {
62 if descending { range.end } else { range.start }
63}
64
65fn group_ranges_by_primary_end(
71 ranges: &[PartitionRange],
72 descending: bool,
73) -> Vec<(Timestamp, usize, usize)> {
74 if ranges.is_empty() {
75 return vec![];
76 }
77
78 let mut groups = Vec::new();
79 let mut group_start = 0;
80 let mut current_primary_end = get_primary_end(&ranges[0], descending);
81
82 for (idx, range) in ranges.iter().enumerate().skip(1) {
83 let primary_end = get_primary_end(range, descending);
84 if primary_end != current_primary_end {
85 groups.push((current_primary_end, group_start, idx));
87 group_start = idx;
89 current_primary_end = primary_end;
90 }
91 }
92 groups.push((current_primary_end, group_start, ranges.len()));
94
95 groups
96}
97
/// Sort executor that exploits pre-partitioned input: rows arrive grouped
/// into [`PartitionRange`]s whose order matches the sort direction, so only
/// the rows within each range group need to be sorted.
#[derive(Debug, Clone)]
pub struct PartSortExec {
    /// Physical sort expression (column plus ordering options).
    expression: PhysicalSortExpr,
    /// Optional TopK row limit.
    limit: Option<usize>,
    /// Child plan producing the unsorted input.
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics shared with the spawned streams.
    metrics: ExecutionPlanMetricsSet,
    /// Per-output-partition lists of ranges covered by the input.
    partition_ranges: Vec<Vec<PartitionRange>>,
    /// Plan properties derived from the input plan.
    properties: Arc<PlanProperties>,
    /// TopK dynamic filter pushed down to the input; present iff `limit` is.
    dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
}
117
impl PartSortExec {
    /// Creates a new [`PartSortExec`].
    ///
    /// Validates that each partition's ranges are monotonic in the sort
    /// direction, derives plan properties from the input, and builds a
    /// dynamic filter when a `limit` is present.
    ///
    /// # Errors
    /// Returns an error if `partition_ranges` is not monotonic with respect
    /// to `expression.options.descending`.
    pub fn try_new(
        expression: PhysicalSortExpr,
        limit: Option<usize>,
        partition_ranges: Vec<Vec<PartitionRange>>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        check_partition_range_monotonicity(&partition_ranges, expression.options.descending)?;

        let metrics = ExecutionPlanMetricsSet::new();
        let properties = input.properties();
        // Inherit equivalence, partitioning, emission type and boundedness
        // from the input plan.
        let properties = Arc::new(PlanProperties::new(
            input.equivalence_properties().clone(),
            input.output_partitioning().clone(),
            properties.emission_type,
            properties.boundedness,
        ));

        let dynamic_filter = Self::new_dynamic_filter(&expression, limit);

        Ok(Self {
            expression,
            limit,
            input,
            metrics,
            partition_ranges,
            properties,
            dynamic_filter,
        })
    }

    /// Builds a TopK dynamic filter over the sort expression, or `None`
    /// when no limit is set. The filter starts as `lit(true)` (accept
    /// everything) and is tightened while the stream runs.
    fn new_dynamic_filter(
        expression: &PhysicalSortExpr,
        limit: Option<usize>,
    ) -> Option<Arc<DynamicFilterPhysicalExpr>> {
        limit.map(|_| {
            Arc::new(DynamicFilterPhysicalExpr::new(
                vec![expression.expr.clone()],
                lit(true),
            ))
        })
    }

    /// Creates the output stream for `partition`.
    ///
    /// # Errors
    /// Returns an internal error when `partition` has no corresponding
    /// entry in `partition_ranges`.
    pub fn to_stream(
        &self,
        context: Arc<TaskContext>,
        partition: usize,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        let input_stream: DfSendableRecordBatchStream =
            self.input.execute(partition, context.clone())?;

        if partition >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                partition,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }

        let df_stream = Box::pin(PartSortStream::new(
            context,
            self,
            self.limit,
            input_stream,
            self.partition_ranges[partition].clone(),
            partition,
        )?) as _;

        Ok(df_stream)
    }
}
190
191impl DisplayAs for PartSortExec {
192 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 write!(
194 f,
195 "PartSortExec: expr={} num_ranges={}",
196 self.expression,
197 self.partition_ranges.len(),
198 )?;
199 if let Some(limit) = self.limit {
200 write!(f, " limit={}", limit)?;
201 }
202 Ok(())
203 }
204}
205
impl ExecutionPlan for PartSortExec {
    fn name(&self) -> &str {
        "PartSortExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Sorting does not change the schema; pass through the input's.
    fn schema(&self) -> SchemaRef {
        self.input.schema()
    }

    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds this node over a new (single) child, adopting the child's
    /// plan properties.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let new_input = if let Some(first) = children.first() {
            first
        } else {
            internal_err!("No children found")?
        };
        let mut new_exec = self.as_ref().clone();
        new_exec.input = new_input.clone();
        new_exec.properties = new_input.properties().clone();
        Ok(Arc::new(new_exec))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
        self.to_stream(context, partition)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Declares that additional input partitioning gives no benefit here.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    /// Without a limit every input row is emitted; with a limit the output
    /// may be truncated.
    fn cardinality_effect(&self) -> CardinalityEffect {
        if self.limit.is_none() {
            CardinalityEffect::Equal
        } else {
            CardinalityEffect::LowerEqual
        }
    }

    /// Forwards parent filters to the child; in the `Post` phase it also
    /// attaches the TopK dynamic filter as a self filter when enabled by
    /// the optimizer config.
    fn gather_filters_for_pushdown(
        &self,
        phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        config: &datafusion::config::ConfigOptions,
    ) -> datafusion_common::Result<FilterDescription> {
        if !matches!(phase, FilterPushdownPhase::Post) {
            return FilterDescription::from_children(parent_filters, &self.children());
        }

        let mut child = ChildFilterDescription::from_child(&parent_filters, &self.input)?;
        if let Some(filter) = &self.dynamic_filter
            && config.optimizer.enable_topk_dynamic_filter_pushdown
        {
            let filter: Arc<dyn PhysicalExpr> = filter.clone();
            child = child.with_self_filter(filter);
        }
        Ok(FilterDescription::new().with_child(child))
    }

    /// Produces a copy with a fresh (accept-all) dynamic filter so the plan
    /// can be re-executed from a clean state.
    fn reset_state(self: Arc<Self>) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        let dynamic_filter = Self::new_dynamic_filter(&self.expression, self.limit);
        Ok(Arc::new(Self {
            expression: self.expression.clone(),
            limit: self.limit,
            input: self.input.clone(),
            metrics: self.metrics.clone(),
            partition_ranges: self.partition_ranges.clone(),
            properties: self.properties.clone(),
            dynamic_filter,
        }))
    }
}
304
/// Buffered input batches awaiting a sort pass.
enum PartSortBuffer {
    /// Full-sort mode (no limit): every batch is kept until a group ends.
    All(Vec<DfRecordBatch>),
    /// TopK mode: periodically compacted down to at most `limit` rows.
    TopK(Vec<DfRecordBatch>),
}
309
/// Sort value of the worst row currently kept in TopK mode; used to tighten
/// the pushed-down dynamic filter and to decide early termination.
#[derive(Clone, Debug, PartialEq, Eq)]
enum TopKThreshold {
    /// The cutoff row's sort value is NULL.
    Null,
    /// Raw timestamp value (interpreted per the sort column's time unit).
    Value(i64),
}
315
316impl PartSortBuffer {
317 pub fn is_empty(&self) -> bool {
318 match self {
319 PartSortBuffer::All(v) => v.is_empty(),
320 PartSortBuffer::TopK(v) => v.is_empty(),
321 }
322 }
323
324 pub fn num_rows(&self) -> usize {
325 match self {
326 PartSortBuffer::All(v) => v.iter().map(|batch| batch.num_rows()).sum(),
327 PartSortBuffer::TopK(v) => v.iter().map(|batch| batch.num_rows()).sum(),
328 }
329 }
330}
331
/// Streaming implementation behind [`PartSortExec`]: buffers incoming rows
/// per [`PartitionRange`] group, sorts and emits them at group boundaries,
/// and in TopK mode keeps a dynamic filter threshold up to date.
struct PartSortStream {
    /// Memory accounting for the temporary concat/take buffers built while
    /// sorting.
    reservation: MemoryReservation,
    /// Batches accumulated for the current range group.
    buffer: PartSortBuffer,
    /// Sort expression plus ordering options.
    expression: PhysicalSortExpr,
    /// Optional TopK row limit.
    limit: Option<usize>,
    /// Upstream record batch stream.
    input: DfSendableRecordBatchStream,
    /// Set when the input is exhausted or terminated early.
    input_complete: bool,
    /// Output schema; same as the input's schema.
    schema: SchemaRef,
    /// Ranges covering this partition's data, in sort order.
    partition_ranges: Vec<PartitionRange>,
    /// Partition index; only read by debug logging.
    #[allow(dead_code)] partition: usize,
    /// Index of the partition range currently being filled.
    cur_part_idx: usize,
    /// Tail of a split batch that still spans further ranges and needs
    /// another split pass before it can be buffered.
    evaluating_batch: Option<DfRecordBatch>,
    /// Baseline metrics (output rows, poll timing) for this partition.
    metrics: BaselineMetrics,
    /// TopK dynamic filter shared with the plan node, if any.
    dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
    /// Last threshold published to the filter; skips redundant updates.
    dynamic_filter_threshold: Option<TopKThreshold>,
    /// Ranges grouped by primary end: `(primary_end, start_idx, end_idx)`.
    range_groups: Vec<(Timestamp, usize, usize)>,
    /// Index of the active entry in `range_groups`.
    cur_group_idx: usize,
}
355
impl PartSortStream {
    /// Creates the stream for one output partition.
    ///
    /// Picks the buffer mode from `limit` and precomputes the range groups
    /// that control when buffered data may be sorted and emitted.
    fn new(
        context: Arc<TaskContext>,
        sort: &PartSortExec,
        limit: Option<usize>,
        input: DfSendableRecordBatchStream,
        partition_ranges: Vec<PartitionRange>,
        partition: usize,
    ) -> datafusion_common::Result<Self> {
        let buffer = if limit.is_some() {
            PartSortBuffer::TopK(Vec::new())
        } else {
            PartSortBuffer::All(Vec::new())
        };

        let descending = sort.expression.options.descending;
        let range_groups = group_ranges_by_primary_end(&partition_ranges, descending);

        Ok(Self {
            // Track temporary sort memory against the task's memory pool.
            reservation: MemoryConsumer::new("PartSortStream".to_string())
                .register(&context.runtime_env().memory_pool),
            buffer,
            expression: sort.expression.clone(),
            limit,
            input,
            input_complete: false,
            schema: sort.input.schema(),
            partition_ranges,
            partition,
            cur_part_idx: 0,
            evaluating_batch: None,
            metrics: BaselineMetrics::new(&sort.metrics, partition),
            dynamic_filter: sort.dynamic_filter.clone(),
            dynamic_filter_threshold: None,
            range_groups,
            cur_group_idx: 0,
        })
    }
}
396
/// Downcasts `$arr` to a primitive timestamp array of type `$t` and checks
/// that the values at `$min_max_idx` (a pair of indices, in either order)
/// fall inside `$cur_range` as a half-open `[start, end)` interval.
/// Also verifies that the range's time unit matches `$unit`.
macro_rules! array_check_helper {
    ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
        if $cur_range.start.unit().as_arrow_time_unit() != $unit
            || $cur_range.end.unit().as_arrow_time_unit() != $unit
        {
            internal_err!(
                "PartitionRange unit mismatch, expect {:?}, found {:?}",
                $cur_range.start.unit(),
                $unit
            )?;
        }
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();

        // Normalize so that `min <= max` regardless of input order.
        let min = arr.value($min_max_idx.0);
        let max = arr.value($min_max_idx.1);
        let (min, max) = if min < max {
            (min, max)
        } else {
            (max, min)
        };
        let cur_min = $cur_range.start.value();
        let cur_max = $cur_range.end.value();
        // Half-open interval check: `start <= value < end`.
        if !(min >= cur_min && max < cur_max) {
            internal_err!(
                "Sort column min/max value out of partition range: sort_column.min_max=[{:?}, {:?}] not in PartitionRange=[{:?}, {:?}]",
                min,
                max,
                cur_min,
                cur_max
            )?;
        }
    }};
}
434
/// Downcasts `$arr` to a primitive timestamp array of type `$t` and reads
/// the value at `$threshold_idx` as a [`TopKThreshold`], mapping an SQL
/// NULL slot to [`TopKThreshold::Null`]. `$unit` is unused but kept for the
/// shared `downcast_ts_array!` helper signature.
macro_rules! threshold_helper {
    ($t:ty, $unit:expr, $arr:expr, $threshold_idx:expr) => {{
        let arr = $arr
            .as_any()
            .downcast_ref::<arrow::array::PrimitiveArray<$t>>()
            .unwrap();
        if arr.is_null($threshold_idx) {
            TopKThreshold::Null
        } else {
            TopKThreshold::Value(arr.value($threshold_idx))
        }
    }};
}
448
449impl PartSortStream {
    /// Sanity-checks that the sorted column's extreme values fall inside
    /// the effective range of the current group.
    ///
    /// `min_max_idx` holds indices into `sort_column` of the smallest and
    /// largest value; the order of the pair does not matter (the check
    /// macro normalizes it).
    fn check_in_range(
        &self,
        sort_column: &ArrayRef,
        min_max_idx: (usize, usize),
    ) -> datafusion_common::Result<()> {
        let Some(cur_range) = self.get_current_group_effective_range() else {
            internal_err!(
                "No effective range for current group {} at {}",
                self.cur_group_idx,
                snafu::location!()
            )?
        };
        // Align the range's time unit with the column's data type.
        let cur_range = project_partition_range_for_sort(cur_range, sort_column.data_type())?;

        downcast_ts_array!(
            sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        Ok(())
    }
478
    /// Scans `sort_column` for the first value outside the current
    /// partition range (`[start, end)`).
    ///
    /// Returns `Ok(Some(idx))` with the row index of that first
    /// out-of-range value, or `Ok(None)` when the whole column (or an empty
    /// column) fits the current range. Null values never trigger a split.
    fn try_find_next_range(
        &self,
        sort_column: &ArrayRef,
    ) -> datafusion_common::Result<Option<usize>> {
        if sort_column.is_empty() {
            return Ok(None);
        }

        // `cur_part_idx` must reference a valid range before splitting.
        if self.cur_part_idx >= self.partition_ranges.len() {
            internal_err!(
                "Partition index out of range: {} >= {} at {}",
                self.cur_part_idx,
                self.partition_ranges.len(),
                snafu::location!()
            )?;
        }
        let cur_range = project_partition_range_for_sort(
            self.partition_ranges[self.cur_part_idx],
            sort_column.data_type(),
        )?;

        let sort_column_iter = downcast_ts_array!(
            sort_column.data_type() => (array_iter_helper, sort_column),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_column.data_type()
            )?,
        );

        for (idx, val) in sort_column_iter {
            // Values inside `[start, end)` stay in the current range; the
            // first value outside marks the split point.
            if let Some(val) = val
                && (val >= cur_range.end.value() || val < cur_range.start.value())
            {
                return Ok(Some(idx));
            }
        }

        Ok(None)
    }
524
525 fn push_buffer(
526 &mut self,
527 batch: DfRecordBatch,
528 sort_data_type: &DataType,
529 ) -> datafusion_common::Result<()> {
530 let topk = matches!(self.buffer, PartSortBuffer::TopK(_));
531 match &mut self.buffer {
532 PartSortBuffer::All(v) => v.push(batch),
533 PartSortBuffer::TopK(v) => v.push(batch),
534 }
535
536 if topk {
537 let threshold = self.compact_topk_buffer(sort_data_type)?;
538 self.update_dynamic_filter(sort_data_type, threshold)?;
539 }
540
541 Ok(())
542 }
543
    /// Compacts the TopK buffer down to at most `limit` rows once it has
    /// grown beyond the limit, returning the new threshold (sort value of
    /// the worst row still kept) when a compaction actually happened.
    ///
    /// Returns `Ok(None)` and leaves the buffer intact when there is no
    /// limit, the buffer is not the TopK variant, the limit is 0, the
    /// buffer is empty, or it still fits within the limit.
    fn compact_topk_buffer(
        &mut self,
        sort_data_type: &DataType,
    ) -> datafusion_common::Result<Option<TopKThreshold>> {
        let Some(limit) = self.limit else {
            return Ok(None);
        };

        // Temporarily take ownership of the buffered batches.
        let PartSortBuffer::TopK(buffer) =
            std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
        else {
            return Ok(None);
        };

        if limit == 0 || buffer.is_empty() {
            self.buffer = PartSortBuffer::TopK(Vec::new());
            return Ok(None);
        }

        let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum();
        if total_rows <= limit {
            // Still under the limit: put the batches back untouched.
            self.buffer = PartSortBuffer::TopK(buffer);
            return Ok(None);
        }

        // Sort and truncate to the top `limit` rows. Range checking is
        // skipped since the buffer may span several partition ranges.
        let topk = self.sort_record_batches(&buffer, Some(limit), false)?;
        let threshold = self.threshold_from_sorted_batch(&topk, sort_data_type)?;
        self.buffer = if topk.num_rows() == 0 {
            PartSortBuffer::TopK(Vec::new())
        } else {
            PartSortBuffer::TopK(vec![topk])
        };

        Ok(threshold)
    }
579
    /// Reads the TopK threshold from a batch that is already sorted and
    /// truncated to the limit: the sort value of its last row.
    ///
    /// Returns `Ok(None)` for an empty batch.
    fn threshold_from_sorted_batch(
        &self,
        batch: &DfRecordBatch,
        sort_data_type: &DataType,
    ) -> datafusion_common::Result<Option<TopKThreshold>> {
        if batch.num_rows() == 0 {
            return Ok(None);
        }

        // The last row of the sorted batch is the worst row still kept.
        let threshold_idx = batch.num_rows() - 1;
        let sort_column = self.expression.evaluate_to_sort_column(batch)?.values;
        let threshold = downcast_ts_array!(
            sort_data_type => (threshold_helper, sort_column, threshold_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_data_type
            )?,
        );

        Ok(Some(threshold))
    }
601
    /// Computes the TopK threshold (sort value of the `limit`-th best row)
    /// from the buffered batches without modifying the buffer.
    ///
    /// Returns `Ok(None)` when no limit is set, the limit is 0, or fewer
    /// than `limit` rows are buffered (no meaningful cutoff yet).
    fn topk_threshold(
        &self,
        sort_data_type: &arrow_schema::DataType,
    ) -> datafusion_common::Result<Option<TopKThreshold>> {
        let Some(limit) = self.limit else {
            return Ok(None);
        };

        if limit == 0 || self.buffer.num_rows() < limit {
            return Ok(None);
        }

        let buffer = match &self.buffer {
            PartSortBuffer::All(buffer) | PartSortBuffer::TopK(buffer) => buffer,
        };
        // Evaluate the sort column per batch; keep the first non-None sort
        // options encountered.
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        // Partial sort: only the top `limit` indices are needed.
        let indices = sort_to_indices(&sort_column, opt, Some(limit)).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;

        if indices.len() < limit {
            return Ok(None);
        }

        // The last of the top-k indices points at the cutoff row.
        let threshold_idx = indices.value(indices.len() - 1) as usize;
        let threshold = downcast_ts_array!(
            sort_data_type => (threshold_helper, sort_column, threshold_idx),
            _ => internal_err!(
                "Unsupported data type for sort column: {:?}",
                sort_data_type
            )?,
        );

        Ok(Some(threshold))
    }
655
656 fn threshold_scalar_value(
657 sort_data_type: &DataType,
658 threshold: &TopKThreshold,
659 ) -> datafusion_common::Result<ScalarValue> {
660 let value = match threshold {
661 TopKThreshold::Null => None,
662 TopKThreshold::Value(value) => Some(*value),
663 };
664
665 let scalar = match sort_data_type {
666 DataType::Timestamp(TimeUnit::Second, tz) => {
667 ScalarValue::TimestampSecond(value, tz.clone())
668 }
669 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
670 ScalarValue::TimestampMillisecond(value, tz.clone())
671 }
672 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
673 ScalarValue::TimestampMicrosecond(value, tz.clone())
674 }
675 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
676 ScalarValue::TimestampNanosecond(value, tz.clone())
677 }
678 _ => internal_err!(
679 "Unsupported data type for sort column: {:?}",
680 sort_data_type
681 )?,
682 };
683
684 Ok(scalar)
685 }
686
    /// Builds the predicate pushed into the dynamic filter from the current
    /// TopK threshold.
    ///
    /// The base comparison is `expr > threshold` for descending sorts and
    /// `expr < threshold` for ascending ones ("strictly better than the
    /// current cutoff"). Null ordering then adjusts the predicate:
    /// - nulls first, null threshold: nothing ranks better => `false`;
    /// - nulls first, value threshold: nulls or better values pass;
    /// - nulls last, null threshold: any non-null value passes;
    /// - nulls last, value threshold: the plain comparison.
    fn build_dynamic_filter_expr(
        &self,
        sort_data_type: &DataType,
        threshold: &TopKThreshold,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        let op = if self.expression.options.descending {
            Operator::Gt
        } else {
            Operator::Lt
        };
        let value_null = matches!(threshold, TopKThreshold::Null);
        let value = Self::threshold_scalar_value(sort_data_type, threshold)?;
        let comparison: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            self.expression.expr.clone(),
            op,
            lit(value),
        ));

        match (self.expression.options.nulls_first, value_null) {
            (true, true) => Ok(lit(false)),
            (true, false) => Ok(Arc::new(BinaryExpr::new(
                is_null(self.expression.expr.clone())?,
                Operator::Or,
                comparison,
            ))),
            (false, true) => is_not_null(self.expression.expr.clone()),
            (false, false) => Ok(comparison),
        }
    }
716
717 fn update_dynamic_filter(
718 &mut self,
719 sort_data_type: &DataType,
720 threshold: Option<TopKThreshold>,
721 ) -> datafusion_common::Result<()> {
722 let Some(filter) = &self.dynamic_filter else {
723 return Ok(());
724 };
725
726 let threshold = if let Some(threshold) = threshold {
727 threshold
728 } else {
729 let Some(threshold) = self.topk_threshold(sort_data_type)? else {
730 return Ok(());
731 };
732 threshold
733 };
734
735 if self.dynamic_filter_threshold.as_ref() == Some(&threshold) {
736 return Ok(());
737 }
738
739 let predicate = self.build_dynamic_filter_expr(sort_data_type, &threshold)?;
740 filter.update(predicate)?;
741 self.dynamic_filter_threshold = Some(threshold);
742
743 Ok(())
744 }
745
    /// Decides whether input can stop before reading `group_idx`: true when
    /// the current TopK threshold already beats the best value any range of
    /// that group could produce (later groups are no better, given the
    /// monotonic ordering of groups).
    fn can_stop_before_group(
        &self,
        group_idx: usize,
        sort_data_type: &arrow_schema::DataType,
    ) -> datafusion_common::Result<bool> {
        if group_idx >= self.range_groups.len() {
            return Ok(false);
        }

        // Prefer the already-published threshold; otherwise compute one on
        // demand from the buffer.
        let threshold = if let Some(threshold) = &self.dynamic_filter_threshold {
            threshold.clone()
        } else {
            let Some(threshold) = self.topk_threshold(sort_data_type)? else {
                return Ok(false);
            };
            threshold
        };

        let (_, start_idx, _) = self.range_groups[group_idx];
        let next_range =
            project_partition_range_for_sort(self.partition_ranges[start_idx], sort_data_type)?;
        let descending = self.expression.options.descending;
        // Best value the next group can possibly produce.
        let next_primary = get_primary_end(&next_range, descending).value();

        let can_stop = match threshold {
            // With nulls_first a null threshold means the kept rows are all
            // nulls and nothing ranks better; with nulls_last a non-null
            // value could still beat it, so keep reading.
            TopKThreshold::Null => self.expression.options.nulls_first,
            TopKThreshold::Value(value) => {
                if descending {
                    // Descending: larger is better; stop once the cutoff is
                    // at least the next group's best possible value.
                    value >= next_primary
                } else {
                    // Ascending: smaller is better.
                    value < next_primary
                }
            }
        };

        Ok(can_stop)
    }
790
791 fn is_in_current_group(&self, part_idx: usize) -> bool {
793 if self.cur_group_idx >= self.range_groups.len() {
794 return false;
795 }
796 let (_, start, end) = self.range_groups[self.cur_group_idx];
797 part_idx >= start && part_idx < end
798 }
799
800 fn advance_to_next_group(&mut self) -> bool {
802 self.cur_group_idx += 1;
803 self.cur_group_idx < self.range_groups.len()
804 }
805
806 fn get_current_group_effective_range(&self) -> Option<PartitionRange> {
810 if self.cur_group_idx >= self.range_groups.len() {
811 return None;
812 }
813 let (_, start_idx, end_idx) = self.range_groups[self.cur_group_idx];
814 if start_idx >= end_idx || start_idx >= self.partition_ranges.len() {
815 return None;
816 }
817
818 let ranges_in_group =
819 &self.partition_ranges[start_idx..end_idx.min(self.partition_ranges.len())];
820 if ranges_in_group.is_empty() {
821 return None;
822 }
823
824 let mut min_start = ranges_in_group[0].start;
826 let mut max_end = ranges_in_group[0].end;
827 for range in ranges_in_group.iter().skip(1) {
828 if range.start < min_start {
829 min_start = range.start;
830 }
831 if range.end > max_end {
832 max_end = range.end;
833 }
834 }
835
836 Some(PartitionRange {
837 start: min_start,
838 end: max_end,
839 num_rows: 0, identifier: 0, })
842 }
843
844 fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
848 match &mut self.buffer {
849 PartSortBuffer::All(_) => self.sort_all_buffer(),
850 PartSortBuffer::TopK(_) => self.sort_topk_buffer(),
851 }
852 }
853
    /// Concatenates `buffer`, sorts the rows by the sort expression
    /// (optionally keeping only the first `limit` rows), and returns the
    /// reordered batch.
    ///
    /// When `check_range` is set, additionally validates that the sorted
    /// min/max stay within the current group's effective range.
    fn sort_record_batches(
        &mut self,
        buffer: &[DfRecordBatch],
        limit: Option<usize>,
        check_range: bool,
    ) -> datafusion_common::Result<DfRecordBatch> {
        if buffer.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        // Evaluate the sort column per batch; keep the first non-None sort
        // options encountered.
        let mut sort_columns = Vec::with_capacity(buffer.len());
        let mut opt = None;
        for batch in buffer.iter() {
            let sort_column = self.expression.evaluate_to_sort_column(batch)?;
            opt = opt.or(sort_column.options);
            sort_columns.push(sort_column.values);
        }

        let sort_column =
            concat(&sort_columns.iter().map(|a| a.as_ref()).collect_vec()).map_err(|e| {
                DataFusionError::ArrowError(
                    Box::new(e),
                    Some(format!("Fail to concat sort columns at {}", location!())),
                )
            })?;

        let indices = sort_to_indices(&sort_column, opt, limit).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!("Fail to sort to indices at {}", location!())),
            )
        })?;
        if indices.is_empty() {
            return Ok(DfRecordBatch::new_empty(self.schema.clone()));
        }

        if check_range {
            // The first/last sorted indices point at the extreme values.
            self.check_in_range(
                &sort_column,
                (
                    indices.value(0) as usize,
                    indices.value(indices.len() - 1) as usize,
                ),
            )
            .inspect_err(|_e| {
                #[cfg(debug_assertions)]
                common_telemetry::error!(
                    "Fail to check sort column in range at {}, current_idx: {}, num_rows: {}, err: {}",
                    self.partition,
                    self.cur_part_idx,
                    sort_column.len(),
                    _e
                );
            })?;
        }

        // Reserve for the concatenated input plus the reordered output.
        let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
        self.reservation.try_grow(total_mem * 2)?;

        let full_input = concat_batches(&self.schema, buffer).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to concat input batches when sorting at {}",
                    location!()
                )),
            )
        })?;

        let sorted = take_record_batch(&full_input, &indices).map_err(|e| {
            DataFusionError::ArrowError(
                Box::new(e),
                Some(format!(
                    "Fail to take result record batch when sorting at {}",
                    location!()
                )),
            )
        })?;

        // The temporary concatenation is gone; release the reservation.
        drop(full_input);
        self.reservation.shrink(2 * total_mem);
        Ok(sorted)
    }
939
940 fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
942 let PartSortBuffer::All(buffer) =
943 std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
944 else {
945 unreachable!()
946 };
947
948 self.sort_record_batches(&buffer, self.limit, self.limit.is_none())
949 }
950
951 fn sort_topk_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
952 let PartSortBuffer::TopK(buffer) =
953 std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
954 else {
955 unreachable!()
956 };
957
958 self.sort_record_batches(&buffer, self.limit, false)
959 }
960
961 fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
963 if self.buffer.is_empty() {
964 return Ok(None);
965 }
966
967 let sorted = self.sort_buffer()?;
968 if sorted.num_rows() == 0 {
969 Ok(None)
970 } else {
971 Ok(Some(sorted))
972 }
973 }
974
975 fn mark_dynamic_filter_complete(&self) {
976 if let Some(filter) = &self.dynamic_filter {
977 filter.mark_complete();
978 }
979 }
980
981 fn split_batch(
998 &mut self,
999 batch: DfRecordBatch,
1000 ) -> datafusion_common::Result<Option<DfRecordBatch>> {
1001 if self.limit.is_some() {
1002 self.split_batch_topk(batch)?;
1003 return Ok(None);
1004 }
1005
1006 self.split_batch_all(batch)
1007 }
1008
    /// Splits a batch in TopK mode. Output is never produced here; it is
    /// emitted once the input finishes.
    ///
    /// Rows belonging to the current partition range are buffered (which
    /// may compact the buffer and tighten the dynamic filter). When the
    /// batch crosses into the next range, the stream advances, and — if the
    /// next range starts a new group that the current threshold proves
    /// cannot contribute — input is terminated early.
    fn split_batch_topk(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
        if batch.num_rows() == 0 {
            return Ok(());
        }

        let sort_column = self
            .expression
            .expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;

        // The whole batch fits the current range: just buffer it.
        let next_range_idx = self.try_find_next_range(&sort_column)?;
        let Some(idx) = next_range_idx else {
            self.push_buffer(batch, sort_column.data_type())?;
            return Ok(());
        };

        let this_range = batch.slice(0, idx);
        let remaining_range = batch.slice(idx, batch.num_rows() - idx);
        if this_range.num_rows() != 0 {
            self.push_buffer(this_range, sort_column.data_type())?;
        }

        // Rows at and after `idx` belong to a later range.
        self.cur_part_idx += 1;

        if self.cur_part_idx >= self.partition_ranges.len() {
            // No ranges left; well-formed input leaves no tail rows here.
            debug_assert!(remaining_range.num_rows() == 0);
            self.input_complete = true;
            return Ok(());
        }

        let in_same_group = self.is_in_current_group(self.cur_part_idx);

        if !in_same_group {
            // Entering a new group: stop early if the threshold already
            // rules the rest of the input out.
            let next_group_idx = self.cur_group_idx + 1;
            if self.can_stop_before_group(next_group_idx, sort_column.data_type())? {
                self.input_complete = true;
                return Ok(());
            }
            self.advance_to_next_group();
        }

        // Defer the tail for another split pass when it still spans more
        // ranges; otherwise buffer it directly.
        let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
        if self.try_find_next_range(&next_sort_column)?.is_some() {
            self.evaluating_batch = Some(remaining_range);
        } else if remaining_range.num_rows() != 0 {
            self.push_buffer(remaining_range, sort_column.data_type())?;
        }

        Ok(())
    }
1073
1074 fn split_batch_all(
1075 &mut self,
1076 batch: DfRecordBatch,
1077 ) -> datafusion_common::Result<Option<DfRecordBatch>> {
1078 if batch.num_rows() == 0 {
1079 return Ok(None);
1080 }
1081
1082 let sort_column = self
1083 .expression
1084 .expr
1085 .evaluate(&batch)?
1086 .into_array(batch.num_rows())?;
1087
1088 let next_range_idx = self.try_find_next_range(&sort_column)?;
1089 let Some(idx) = next_range_idx else {
1090 self.push_buffer(batch, sort_column.data_type())?;
1091 return Ok(None);
1093 };
1094
1095 let this_range = batch.slice(0, idx);
1096 let remaining_range = batch.slice(idx, batch.num_rows() - idx);
1097 if this_range.num_rows() != 0 {
1098 self.push_buffer(this_range, sort_column.data_type())?;
1099 }
1100
1101 self.cur_part_idx += 1;
1103
1104 if self.cur_part_idx >= self.partition_ranges.len() {
1106 debug_assert!(remaining_range.num_rows() == 0);
1108
1109 return self.sorted_buffer_if_non_empty();
1111 }
1112
1113 if self.is_in_current_group(self.cur_part_idx) {
1115 let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
1117 if self.try_find_next_range(&next_sort_column)?.is_some() {
1118 self.evaluating_batch = Some(remaining_range);
1120 } else {
1121 if remaining_range.num_rows() != 0 {
1123 self.push_buffer(remaining_range, sort_column.data_type())?;
1124 }
1125 }
1126 return Ok(None);
1128 }
1129
1130 let sorted_batch = self.sorted_buffer_if_non_empty()?;
1132 self.advance_to_next_group();
1133
1134 let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
1135 if self.try_find_next_range(&next_sort_column)?.is_some() {
1136 self.evaluating_batch = Some(remaining_range);
1139 } else {
1140 if remaining_range.num_rows() != 0 {
1143 self.push_buffer(remaining_range, sort_column.data_type())?;
1144 }
1145 }
1146
1147 Ok(sorted_batch)
1148 }
1149
    /// Core polling loop.
    ///
    /// Per iteration, in order:
    /// 1. If input is complete, flush remaining buffered rows, mark the
    ///    dynamic filter complete, and finish.
    /// 2. If a previously split batch is still pending, continue splitting
    ///    it before touching the input again.
    /// 3. Otherwise poll the input stream and split the next batch.
    pub fn poll_next_inner(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        loop {
            if self.input_complete {
                if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                    self.mark_dynamic_filter_complete();
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                self.mark_dynamic_filter_complete();
                return Poll::Ready(None);
            }

            // Resume splitting a leftover batch that spans several ranges.
            if let Some(evaluating_batch) = self.evaluating_batch.take()
                && evaluating_batch.num_rows() != 0
            {
                // All ranges consumed: emit what is buffered and finish.
                if self.cur_part_idx >= self.partition_ranges.len() {
                    if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? {
                        self.mark_dynamic_filter_complete();
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                    self.mark_dynamic_filter_complete();
                    return Poll::Ready(None);
                }

                if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
                    return Poll::Ready(Some(Ok(sorted_batch)));
                }
                continue;
            }

            let res = self.input.as_mut().poll_next(cx);
            match res {
                Poll::Ready(Some(Ok(batch))) => {
                    if let Some(sorted_batch) = self.split_batch(batch)? {
                        return Poll::Ready(Some(Ok(sorted_batch)));
                    }
                }
                Poll::Ready(None) => {
                    // The final flush happens at the top of the next loop.
                    self.input_complete = true;
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
1203}
1204
impl Stream for PartSortStream {
    type Item = datafusion_common::Result<DfRecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
        let result = self.as_mut().poll_next_inner(cx);
        // Fold output rows and poll timing into the baseline metrics.
        self.metrics.record_poll(result)
    }
}
1216
1217impl RecordBatchStream for PartSortStream {
1218 fn schema(&self) -> SchemaRef {
1219 self.schema.clone()
1220 }
1221}
1222
1223#[cfg(test)]
1224mod test {
1225 use std::sync::Arc;
1226
1227 use arrow::array::{
1228 BooleanArray, TimestampMicrosecondArray, TimestampMillisecondArray,
1229 TimestampNanosecondArray, TimestampSecondArray,
1230 };
1231 use arrow::json::ArrayWriter;
1232 use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
1233 use common_time::Timestamp;
1234 use datafusion_physical_expr::expressions::Column;
1235 use futures::StreamExt;
1236 use store_api::region_engine::PartitionRange;
1237
1238 use super::*;
1239 use crate::test_util::{MockInputExec, new_ts_array};
1240
1241 #[ignore = "hard to gen expected data correctly here, TODO(discord9): fix it later"]
1242 #[tokio::test]
1243 async fn fuzzy_test() {
1244 let test_cnt = 100;
1245 let part_cnt_bound = 100;
1247 let range_size_bound = 100;
1249 let range_offset_bound = 100;
1250 let batch_cnt_bound = 20;
1252 let batch_size_bound = 100;
1253
1254 let mut rng = fastrand::Rng::new();
1255 rng.seed(1337);
1256
1257 let mut test_cases = Vec::new();
1258
1259 for case_id in 0..test_cnt {
1260 let mut bound_val: Option<i64> = None;
1261 let descending = rng.bool();
1262 let nulls_first = rng.bool();
1263 let opt = SortOptions {
1264 descending,
1265 nulls_first,
1266 };
1267 let limit = if rng.bool() {
1268 Some(rng.usize(1..batch_cnt_bound * batch_size_bound))
1269 } else {
1270 None
1271 };
1272 let unit = match rng.u8(0..3) {
1273 0 => TimeUnit::Second,
1274 1 => TimeUnit::Millisecond,
1275 2 => TimeUnit::Microsecond,
1276 _ => TimeUnit::Nanosecond,
1277 };
1278
1279 let schema = Schema::new(vec![Field::new(
1280 "ts",
1281 DataType::Timestamp(unit, None),
1282 false,
1283 )]);
1284 let schema = Arc::new(schema);
1285
1286 let mut input_ranged_data = vec![];
1287 let mut output_ranges = vec![];
1288 let mut output_data = vec![];
1289 for part_id in 0..rng.usize(0..part_cnt_bound) {
1291 let (start, end) = if descending {
1293 let end = bound_val
1295 .map(
1296 |i| i
1297 .checked_sub(rng.i64(1..=range_offset_bound))
1298 .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
1299 )
1300 .unwrap_or_else(|| rng.i64(-100000000..100000000));
1301 bound_val = Some(end);
1302 let start = end - rng.i64(1..range_size_bound);
1303 let start = Timestamp::new(start, unit.into());
1304 let end = Timestamp::new(end, unit.into());
1305 (start, end)
1306 } else {
1307 let start = bound_val
1309 .map(|i| i + rng.i64(1..=range_offset_bound))
1310 .unwrap_or_else(|| rng.i64(..));
1311 bound_val = Some(start);
1312 let end = start + rng.i64(1..range_size_bound);
1313 let start = Timestamp::new(start, unit.into());
1314 let end = Timestamp::new(end, unit.into());
1315 (start, end)
1316 };
1317 assert!(start < end);
1318
1319 let mut per_part_sort_data = vec![];
1320 let mut batches = vec![];
1321 for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
1322 let cnt = rng.usize(0..batch_size_bound) + 1;
1323 let iter = 0..rng.usize(0..cnt);
1324 let mut data_gen = iter
1325 .map(|_| rng.i64(start.value()..end.value()))
1326 .collect_vec();
1327 if data_gen.is_empty() {
1328 continue;
1330 }
1331 data_gen.sort();
1333 per_part_sort_data.extend(data_gen.clone());
1334 let arr = new_ts_array(unit, data_gen.clone());
1335 let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
1336 batches.push(batch);
1337 }
1338
1339 let range = PartitionRange {
1340 start,
1341 end,
1342 num_rows: batches.iter().map(|b| b.num_rows()).sum(),
1343 identifier: part_id,
1344 };
1345 input_ranged_data.push((range, batches));
1346
1347 output_ranges.push(range);
1348 if per_part_sort_data.is_empty() {
1349 continue;
1350 }
1351 output_data.extend_from_slice(&per_part_sort_data);
1352 }
1353
1354 let mut output_data_iter = output_data.iter().peekable();
1356 let mut output_data = vec![];
1357 for range in output_ranges.clone() {
1358 let mut cur_data = vec![];
1359 while let Some(val) = output_data_iter.peek() {
1360 if **val < range.start.value() || **val >= range.end.value() {
1361 break;
1362 }
1363 cur_data.push(*output_data_iter.next().unwrap());
1364 }
1365
1366 if cur_data.is_empty() {
1367 continue;
1368 }
1369
1370 if descending {
1371 cur_data.sort_by(|a, b| b.cmp(a));
1372 } else {
1373 cur_data.sort();
1374 }
1375 output_data.push(cur_data);
1376 }
1377
1378 let expected_output = if let Some(limit) = limit {
1379 let mut accumulated = Vec::new();
1380 let mut seen = 0usize;
1381 for mut range_values in output_data {
1382 seen += range_values.len();
1383 accumulated.append(&mut range_values);
1384 if seen >= limit {
1385 break;
1386 }
1387 }
1388
1389 if accumulated.is_empty() {
1390 None
1391 } else {
1392 if descending {
1393 accumulated.sort_by(|a, b| b.cmp(a));
1394 } else {
1395 accumulated.sort();
1396 }
1397 accumulated.truncate(limit.min(accumulated.len()));
1398
1399 Some(
1400 DfRecordBatch::try_new(
1401 schema.clone(),
1402 vec![new_ts_array(unit, accumulated)],
1403 )
1404 .unwrap(),
1405 )
1406 }
1407 } else {
1408 let batches = output_data
1409 .into_iter()
1410 .map(|a| {
1411 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
1412 })
1413 .collect_vec();
1414 if batches.is_empty() {
1415 None
1416 } else {
1417 Some(concat_batches(&schema, &batches).unwrap())
1418 }
1419 };
1420
1421 test_cases.push((
1422 case_id,
1423 unit,
1424 input_ranged_data,
1425 schema,
1426 opt,
1427 limit,
1428 expected_output,
1429 ));
1430 }
1431
1432 for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
1433 run_test(
1434 case_id,
1435 input_ranged_data,
1436 schema,
1437 opt,
1438 limit,
1439 expected_output,
1440 None,
1441 )
1442 .await;
1443 }
1444 }
1445
    /// Table-driven sanity checks for the partition sort.
    ///
    /// Each case is `(time_unit, inputs, descending, limit, expected)`, where
    /// `inputs` is a list of `((range_start, range_end), batches)` and
    /// `expected` is the per-range sorted output runs (concatenated before
    /// comparison). Covers overlapping ranges, empty ranges, fully empty
    /// input, a single batch spanning several ranges, and a limit.
    #[tokio::test]
    async fn simple_cases() {
        let testcases = vec![
            // Ascending: two overlapping ranges merge into one sorted run.
            (
                TimeUnit::Millisecond,
                vec![
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]),
                    ((5, 10), vec![vec![5, 6], vec![7, 8]]),
                ],
                false,
                None,
                vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
            ),
            // Descending: same data, reversed ordering.
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![vec![5, 6], vec![7, 8, 9]]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 3, 2, 1]],
            ),
            // A range with no batches is skipped without affecting output.
            (
                TimeUnit::Millisecond,
                vec![
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            // Disjoint ranges produce separate sorted runs.
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![vec![17, 18, 19]]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
                ],
                true,
                None,
                vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
            ),
            // Entirely empty input yields no output at all.
            (
                TimeUnit::Millisecond,
                vec![
                    ((15, 20), vec![]),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![],
            ),
            // One batch carrying data that belongs to several ranges still
            // gets split into per-range sorted runs.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                None,
                vec![
                    vec![19, 17, 15],
                    vec![12, 11, 10],
                    vec![9, 8, 7, 6, 5, 4, 3, 2, 1],
                ],
            ),
            // Same input with limit 2: only the top two values survive.
            (
                TimeUnit::Millisecond,
                vec![
                    (
                        (15, 20),
                        vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
                    ),
                    ((10, 15), vec![]),
                    ((5, 10), vec![]),
                    ((0, 10), vec![]),
                ],
                true,
                Some(2),
                vec![vec![19, 17]],
            ),
        ];

        for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
            testcases.into_iter().enumerate()
        {
            let schema = Schema::new(vec![Field::new(
                "ts",
                DataType::Timestamp(unit, None),
                false,
            )]);
            let schema = Arc::new(schema);
            let opt = SortOptions {
                descending,
                ..Default::default()
            };

            // Lift the plain tuples/vecs into PartitionRange + RecordBatch form.
            let input_ranged_data = input_ranged_data
                .into_iter()
                .map(|(range, data)| {
                    let part = PartitionRange {
                        start: Timestamp::new(range.0, unit.into()),
                        end: Timestamp::new(range.1, unit.into()),
                        num_rows: data.iter().map(|b| b.len()).sum(),
                        identifier,
                    };

                    let batches = data
                        .into_iter()
                        .map(|b| {
                            let arr = new_ts_array(unit, b);
                            DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap()
                        })
                        .collect_vec();
                    (part, batches)
                })
                .collect_vec();

            // Expected runs are concatenated into a single batch (None if empty).
            let expected_output = expected_output
                .into_iter()
                .map(|a| {
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
                })
                .collect_vec();
            let expected_output = if expected_output.is_empty() {
                None
            } else {
                Some(concat_batches(&schema, &expected_output).unwrap())
            };

            run_test(
                identifier,
                input_ranged_data,
                schema.clone(),
                opt,
                limit,
                expected_output,
                None,
            )
            .await;
        }
    }
1604
    /// Shared driver for all scenario tests: builds a `PartSortExec` over a
    /// `MockInputExec`, runs it to completion, and compares against
    /// `expected_output` (`None` means "no batch at all").
    ///
    /// When `expected_polled_rows` is `Some`, also asserts how many rows the
    /// operator pulled from its input — the early-termination tests use this
    /// to prove that pulling stopped before the input was exhausted.
    /// Panics with a JSON dump of both sides on mismatch.
    #[allow(clippy::print_stdout)]
    async fn run_test(
        case_id: usize,
        input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
        schema: SchemaRef,
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Option<DfRecordBatch>,
        expected_polled_rows: Option<usize>,
    ) {
        // Guard against a malformed test case: the expectation itself must
        // respect the limit.
        if let (Some(limit), Some(rb)) = (limit, &expected_output) {
            assert!(
                rb.num_rows() <= limit,
                "Expect row count in expected output({}) <= limit({})",
                rb.num_rows(),
                limit
            );
        }

        let mut data_partition = Vec::with_capacity(input_ranged_data.len());
        let mut ranges = Vec::with_capacity(input_ranged_data.len());
        for (part_range, batches) in input_ranged_data {
            data_partition.push(batches);
            ranges.push(part_range);
        }

        let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));

        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: opt,
            },
            limit,
            vec![ranges.clone()],
            mock_input.clone(),
        )
        .unwrap();

        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();

        // Drain the stream; any error in a batch fails the test immediately.
        let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
        // With a limit, the operator is expected to emit at most one batch.
        if limit.is_some() {
            assert!(
                real_output.len() <= 1,
                "case_{case_id} expects a single output batch when limit is set, got {}",
                real_output.len()
            );
        }

        let actual_output = if real_output.is_empty() {
            None
        } else {
            Some(concat_batches(&schema, &real_output).unwrap())
        };

        // Check how many rows were actually pulled from the mock input.
        if let Some(expected_polled_rows) = expected_polled_rows {
            let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
            assert_eq!(input_pulled_rows, expected_polled_rows);
        }

        match (actual_output, expected_output) {
            (None, None) => {}
            (Some(actual), Some(expected)) => {
                if actual != expected {
                    // Render both batches as JSON so mismatches are readable.
                    let mut actual_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut actual_json);
                    writer.write(&actual).unwrap();
                    writer.finish().unwrap();

                    let mut expected_json: Vec<u8> = Vec::new();
                    let mut writer = ArrayWriter::new(&mut expected_json);
                    writer.write(&expected).unwrap();
                    writer.finish().unwrap();

                    panic!(
                        "case_{} failed (limit {limit:?}), opt: {:?},\nreal_output: {}\nexpected: {}",
                        case_id,
                        opt,
                        String::from_utf8_lossy(&actual_json),
                        String::from_utf8_lossy(&expected_json),
                    );
                }
            }
            (None, Some(expected)) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output is empty, expected {} rows",
                case_id,
                opt,
                expected.num_rows()
            ),
            (Some(actual), None) => panic!(
                "case_{} failed (limit {limit:?}), opt: {:?},\nreal output has {} rows, expected empty",
                case_id,
                opt,
                actual.num_rows()
            ),
        }
    }
1703
    /// Limit handling when a partition holds several batches: the top-k must
    /// consider all batches of a partition, not just the first one. Covers
    /// one partition / ascending batch order, two partitions, and one
    /// partition whose batches arrive in descending order.
    #[tokio::test]
    async fn test_limit_with_multiple_batches_per_partition() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        // Case 1000: single partition, three batches, descending, limit 3.
        let input_ranged_data = vec![(
            PartitionRange {
                start: Timestamp::new(0, unit.into()),
                end: Timestamp::new(10, unit.into()),
                num_rows: 9,
                identifier: 0,
            },
            vec![
                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
                    .unwrap(),
                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
                    .unwrap(),
                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
                    .unwrap(),
            ],
        )];

        // Top 3 of 1..=9 descending — values come from the *last* batch.
        let expected_output = Some(
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
                .unwrap(),
        );

        run_test(
            1000,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(3),
            expected_output,
            None,
        )
        .await;

        // Case 1001: two partitions; the limit is satisfied entirely by the
        // first (higher) partition's data.
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(20, unit.into()),
                    num_rows: 6,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![10, 11, 12])],
                    )
                    .unwrap(),
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![13, 14, 15])],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(0, unit.into()),
                    end: Timestamp::new(10, unit.into()),
                    num_rows: 5,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
                        .unwrap(),
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
                        .unwrap(),
                ],
            ),
        ];

        let expected_output = Some(
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
        );

        run_test(
            1001,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(2),
            expected_output,
            None,
        )
        .await;

        // Case 1002: batches arrive high-to-low but the sort is ascending;
        // the smallest values sit in the *last* batch.
        let input_ranged_data = vec![(
            PartitionRange {
                start: Timestamp::new(0, unit.into()),
                end: Timestamp::new(10, unit.into()),
                num_rows: 9,
                identifier: 0,
            },
            vec![
                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
                    .unwrap(),
                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
                    .unwrap(),
                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
                    .unwrap(),
            ],
        )];

        let expected_output = Some(
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
        );

        run_test(
            1002,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: false,
                ..Default::default()
            },
            Some(2),
            expected_output,
            None,
        )
        .await;
    }
1849
    /// White-box test of the top-k buffering inside `PartSortStream`
    /// (limit 3, descending): after each `push_buffer` the internal buffer
    /// stays capped at 3 rows, the dynamic filter's predicate only admits
    /// values that beat the current k-th value, and `sort_buffer` yields the
    /// final top-3 in sorted order.
    #[test]
    fn test_topk_buffer_is_bounded_and_updates_dynamic_filter() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));
        let sort_data_type = DataType::Timestamp(unit, None);
        let partition_range = PartitionRange {
            start: Timestamp::new(0, unit.into()),
            end: Timestamp::new(10, unit.into()),
            num_rows: 9,
            identifier: 0,
        };
        // The mock input is only needed to construct the stream; batches are
        // pushed directly via `push_buffer` below.
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: true,
                    ..Default::default()
                },
            },
            Some(3),
            vec![vec![partition_range]],
            mock_input.clone(),
        )
        .unwrap();
        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(3),
            input_stream,
            vec![partition_range],
            0,
        )
        .unwrap();

        // Push 9 rows in 3 batches; the buffer must never exceed the limit.
        for batch in [vec![1, 2, 3], vec![4, 5, 6], vec![0, 7, 8]] {
            stream
                .push_buffer(
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, batch)])
                        .unwrap(),
                    &sort_data_type,
                )
                .unwrap();
            assert_eq!(stream.buffer.num_rows(), 3);
        }

        // Current top-3 is {8, 7, 6}, so the filter should pass only
        // values > 6: probe [5, 6, 7] -> [false, false, true].
        let dynamic_filter = stream.dynamic_filter.as_ref().unwrap().clone();
        let probe = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7])])
            .unwrap();
        let predicate = dynamic_filter.current().unwrap();
        let result = predicate
            .evaluate(&probe)
            .unwrap()
            .into_array(probe.num_rows())
            .unwrap();
        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(result, &BooleanArray::from(vec![false, false, true]));

        let expected =
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![8, 7, 6])])
                .unwrap();
        assert_eq!(stream.sort_buffer().unwrap(), expected);
    }
1920
    /// Degenerate top-k: with `limit == 0` every pushed batch is discarded
    /// (buffer stays empty) and no dynamic-filter threshold is ever set.
    #[test]
    fn test_topk_limit_zero_clears_buffer_without_threshold() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));
        let sort_data_type = DataType::Timestamp(unit, None);
        let partition_range = PartitionRange {
            start: Timestamp::new(0, unit.into()),
            end: Timestamp::new(10, unit.into()),
            num_rows: 3,
            identifier: 0,
        };
        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
        let exec = PartSortExec::try_new(
            PhysicalSortExpr {
                expr: Arc::new(Column::new("ts", 0)),
                options: SortOptions {
                    descending: true,
                    ..Default::default()
                },
            },
            Some(0),
            vec![vec![partition_range]],
            mock_input.clone(),
        )
        .unwrap();
        let input_stream = mock_input
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        let mut stream = PartSortStream::new(
            Arc::new(TaskContext::default()),
            &exec,
            Some(0),
            input_stream,
            vec![partition_range],
            0,
        )
        .unwrap();

        stream
            .push_buffer(
                DfRecordBatch::try_new(schema, vec![new_ts_array(unit, vec![1, 2, 3])]).unwrap(),
                &sort_data_type,
            )
            .unwrap();

        // Nothing retained, no threshold computed.
        assert_eq!(stream.buffer.num_rows(), 0);
        assert_eq!(stream.dynamic_filter_threshold, None);
    }
1973
    /// Early termination: three disjoint partitions ordered high-to-low with
    /// limit 2 (descending). The top-2 {29, 28} live entirely in the first
    /// partition, so only its 10 rows should ever be pulled from the input
    /// (asserted via `expected_polled_rows = Some(10)`).
    #[tokio::test]
    async fn test_early_termination() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        let input_ranged_data = vec![
            // [20, 30): contains everything the limit needs.
            (
                PartitionRange {
                    start: Timestamp::new(20, unit.into()),
                    end: Timestamp::new(30, unit.into()),
                    num_rows: 10,
                    identifier: 2,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
                    )
                    .unwrap(),
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
                    )
                    .unwrap(),
                ],
            ),
            // [10, 20): should never be pulled.
            (
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(20, unit.into()),
                    num_rows: 10,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
                    )
                    .unwrap(),
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
                    )
                    .unwrap(),
                ],
            ),
            // [0, 10): should never be pulled either.
            (
                PartitionRange {
                    start: Timestamp::new(0, unit.into()),
                    end: Timestamp::new(10, unit.into()),
                    num_rows: 10,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
                    )
                    .unwrap(),
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // NOTE: 30 sits at the range's exclusive end; top-2 here is 29, 28.
        let expected_output = Some(
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![29, 28])]).unwrap(),
        );

        run_test(
            1003,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(2),
            expected_output,
            Some(10),
        )
        .await;
    }
2074
    /// Two partitions sharing the same primary end (100, descending) must be
    /// sorted together as one group: with limit 4 the result interleaves rows
    /// from both partitions ([95, 95, 90, 85]) instead of taking the top-4
    /// from the first partition alone.
    #[tokio::test]
    async fn test_primary_end_grouping_with_limit() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        let input_ranged_data = vec![
            // [70, 100): ends at 100.
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 3,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![80, 90, 95])],
                    )
                    .unwrap(),
                ],
            ),
            // [50, 100): same end as above -> same sort group.
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 5,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![55, 65, 75, 85, 95])],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // Combined descending order: 95, 95, 90, 85, ... — limit keeps 4.
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_ts_array(unit, vec![95, 95, 90, 85])],
            )
            .unwrap(),
        );

        run_test(
            2000,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(4),
            expected_output,
            None,
        )
        .await;
    }
2148
    /// Three overlapping ranges where a later range ([40, 95)) can still hold
    /// values (94) that belong in the top-4: the operator must keep pulling
    /// past the first group instead of terminating as soon as the limit's row
    /// count is reached.
    #[tokio::test]
    async fn test_three_ranges_keep_pulling() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 3,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![80, 90, 95])],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![55, 75, 85])],
                    )
                    .unwrap(),
                ],
            ),
            // [40, 95) contributes 94, which outranks 90 and 85.
            (
                PartitionRange {
                    start: Timestamp::new(40, unit.into()),
                    end: Timestamp::new(95, unit.into()),
                    num_rows: 3,
                    identifier: 2,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![45, 65, 94])],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // Global descending top-4 across all three ranges.
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_ts_array(unit, vec![95, 94, 90, 85])],
            )
            .unwrap(),
        );

        run_test(
            2001,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(4),
            expected_output,
            None,
        )
        .await;
    }
2242
    /// Threshold-based early stop: after the first partition yields top-4
    /// {99, 98, 97, 96}, the k-th value (96) already exceeds the next
    /// partition's range end (90), so its rows cannot improve the result.
    /// Both partitions are still pulled here (9 rows total, asserted below),
    /// but the second partition's data must not change the output.
    #[tokio::test]
    async fn test_threshold_based_early_termination() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 6,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
                    )
                    .unwrap(),
                ],
            ),
            // [50, 90): entirely below the threshold established above.
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![85, 86, 87])],
                    )
                    .unwrap(),
                ],
            ),
        ];

        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
            )
            .unwrap(),
        );

        run_test(
            2002,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(4),
            expected_output,
            // 6 + 3 = 9 input rows pulled in total.
            Some(9), )
        .await;
    }
2316
    /// The threshold (96) falls *inside* the next range [50, 98), so that
    /// range could in principle still contribute rows — the operator must
    /// keep pulling it (9 rows total, asserted) even though, here, its actual
    /// values (55, 60, 65) end up below the cut.
    #[tokio::test]
    async fn test_continue_when_threshold_in_next_group_range() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(90, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 6,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
                    )
                    .unwrap(),
                ],
            ),
            // Range end 98 > threshold 96 -> cannot be skipped.
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(98, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![55, 60, 65])],
                    )
                    .unwrap(),
                ],
            ),
        ];

        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_ts_array(unit, vec![99, 98, 97, 96])],
            )
            .unwrap(),
        );

        run_test(
            2003,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                ..Default::default()
            },
            Some(4),
            expected_output,
            // Both partitions pulled: 6 + 3 = 9 rows.
            Some(9), )
        .await;
    }
2394
    /// Ascending variant of threshold-based early termination: top-4 smallest
    /// values {10, 11, 12, 13} come from the first two (overlapping) ranges;
    /// the last range ([61, 70)) starts above the threshold, so by the
    /// polled-rows assertion only 6 + 3 + 2 = 11 rows are pulled and the final
    /// range's rows (71, 72) are never fetched.
    #[tokio::test]
    async fn test_ascending_threshold_early_termination() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(10, unit.into()),
                    end: Timestamp::new(50, unit.into()),
                    num_rows: 6,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![10, 11, 12, 13, 14, 15])],
                    )
                    .unwrap(),
                ],
            ),
            // Overlaps the first range, so it must still be consumed.
            (
                PartitionRange {
                    start: Timestamp::new(20, unit.into()),
                    end: Timestamp::new(60, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![25, 30, 35])],
                    )
                    .unwrap(),
                ],
            ),
            // [60, 70): pulled (2 rows) before termination kicks in.
            (
                PartitionRange {
                    start: Timestamp::new(60, unit.into()),
                    end: Timestamp::new(70, unit.into()),
                    num_rows: 2,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![60, 61])])
                        .unwrap(),
                ],
            ),
            // [61, 70): beyond the threshold — never pulled.
            (
                PartitionRange {
                    start: Timestamp::new(61, unit.into()),
                    end: Timestamp::new(70, unit.into()),
                    num_rows: 2,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![71, 72])])
                        .unwrap(),
                ],
            ),
        ];

        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_ts_array(unit, vec![10, 11, 12, 13])],
            )
            .unwrap(),
        );

        run_test(
            2004,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: false,
                ..Default::default()
            },
            Some(4),
            expected_output,
            // 6 + 3 + 2 = 11 rows pulled; the last range is skipped.
            Some(11), )
        .await;
    }
2493
    /// Second ascending early-termination scenario: three overlapping low
    /// ranges provide the top-4 smallest values {5, 6, 7, 8}; the two
    /// trailing high ranges ([42, 52) and [48, 53)) start above the
    /// threshold. The polled-rows assertion (4 + 1 + 4 + 2 = 11) shows the
    /// very last range is never pulled.
    #[tokio::test]
    async fn test_ascending_threshold_early_termination_case_two() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            false,
        )]));

        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(0, unit.into()),
                    end: Timestamp::new(20, unit.into()),
                    num_rows: 4,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![9, 10, 11, 12])],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(4, unit.into()),
                    end: Timestamp::new(25, unit.into()),
                    num_rows: 1,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21])])
                        .unwrap(),
                ],
            ),
            // This overlapping range holds the actual smallest values.
            (
                PartitionRange {
                    start: Timestamp::new(5, unit.into()),
                    end: Timestamp::new(25, unit.into()),
                    num_rows: 4,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_ts_array(unit, vec![5, 6, 7, 8])],
                    )
                    .unwrap(),
                ],
            ),
            // [42, 52): pulled (2 rows) before termination.
            (
                PartitionRange {
                    start: Timestamp::new(42, unit.into()),
                    end: Timestamp::new(52, unit.into()),
                    num_rows: 2,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![42, 51])])
                        .unwrap(),
                ],
            ),
            // [48, 53): beyond the threshold — never pulled.
            (
                PartitionRange {
                    start: Timestamp::new(48, unit.into()),
                    end: Timestamp::new(53, unit.into()),
                    num_rows: 2,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![48, 51])])
                        .unwrap(),
                ],
            ),
        ];

        let expected_output = Some(
            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7, 8])])
                .unwrap(),
        );

        run_test(
            2005,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: false,
                ..Default::default()
            },
            Some(4),
            expected_output,
            // 4 + 1 + 4 + 2 = 11 rows pulled; the final range is skipped.
            Some(11), )
        .await;
    }
2601
    /// Early stop in the presence of NULL timestamps, for both `nulls_first`
    /// settings (descending). In each sub-case only the first partition's
    /// 5 rows plus the second partition's 3 rows are pulled (8, asserted),
    /// and NULL ordering in the limited output matches the sort options.
    #[tokio::test]
    async fn test_early_stop_with_nulls() {
        let unit = TimeUnit::Millisecond;
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(unit, None),
            true, )]));

        // Local helper: build a nullable timestamp array in the given unit.
        let new_nullable_ts_array = |unit: TimeUnit, arr: Vec<Option<i64>>| -> ArrayRef {
            match unit {
                TimeUnit::Second => Arc::new(TimestampSecondArray::from(arr)) as ArrayRef,
                TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(arr)) as ArrayRef,
                TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(arr)) as ArrayRef,
                TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(arr)) as ArrayRef,
            }
        };

        // Case 3000: descending, nulls_first = true.
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 5,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(99), Some(98), None, Some(97), None],
                        )],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(89), Some(88), Some(87)],
                        )],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // Nulls sort before everything: [NULL, NULL, 99].
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_nullable_ts_array(unit, vec![None, None, Some(99)])],
            )
            .unwrap(),
        );

        run_test(
            3000,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                nulls_first: true,
            },
            Some(3),
            expected_output,
            // 5 + 3 = 8 input rows pulled.
            Some(8), )
        .await;

        // Case 3001: descending, nulls_first = false.
        let input_ranged_data = vec![
            (
                PartitionRange {
                    start: Timestamp::new(70, unit.into()),
                    end: Timestamp::new(100, unit.into()),
                    num_rows: 5,
                    identifier: 0,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(99), Some(98), Some(97), None, None],
                        )],
                    )
                    .unwrap(),
                ],
            ),
            (
                PartitionRange {
                    start: Timestamp::new(50, unit.into()),
                    end: Timestamp::new(90, unit.into()),
                    num_rows: 3,
                    identifier: 1,
                },
                vec![
                    DfRecordBatch::try_new(
                        schema.clone(),
                        vec![new_nullable_ts_array(
                            unit,
                            vec![Some(89), Some(88), Some(87)],
                        )],
                    )
                    .unwrap(),
                ],
            ),
        ];

        // Nulls sort last, so the top-3 are the plain values [99, 98, 97].
        let expected_output = Some(
            DfRecordBatch::try_new(
                schema.clone(),
                vec![new_nullable_ts_array(
                    unit,
                    vec![Some(99), Some(98), Some(97)],
                )],
            )
            .unwrap(),
        );

        run_test(
            3001,
            input_ranged_data,
            schema.clone(),
            SortOptions {
                descending: true,
                nulls_first: false,
            },
            Some(3),
            expected_output,
            // 5 + 3 = 8 input rows pulled.
            Some(8), )
        .await;
    }
2760
2761 #[tokio::test]
2764 async fn test_early_stop_single_group() {
2765 let unit = TimeUnit::Millisecond;
2766 let schema = Arc::new(Schema::new(vec![Field::new(
2767 "ts",
2768 DataType::Timestamp(unit, None),
2769 false,
2770 )]));
2771
2772 let input_ranged_data = vec![
2774 (
2775 PartitionRange {
2776 start: Timestamp::new(70, unit.into()),
2777 end: Timestamp::new(100, unit.into()),
2778 num_rows: 6,
2779 identifier: 0,
2780 },
2781 vec![
2782 DfRecordBatch::try_new(
2783 schema.clone(),
2784 vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])],
2785 )
2786 .unwrap(),
2787 ],
2788 ),
2789 (
2790 PartitionRange {
2791 start: Timestamp::new(50, unit.into()),
2792 end: Timestamp::new(100, unit.into()),
2793 num_rows: 3,
2794 identifier: 1,
2795 },
2796 vec![
2797 DfRecordBatch::try_new(
2798 schema.clone(),
2799 vec![new_ts_array(unit, vec![85, 86, 87])],
2800 )
2801 .unwrap(),
2802 ],
2803 ),
2804 ];
2805
2806 let expected_output = Some(
2809 DfRecordBatch::try_new(
2810 schema.clone(),
2811 vec![new_ts_array(unit, vec![99, 98, 97, 96])],
2812 )
2813 .unwrap(),
2814 );
2815
2816 run_test(
2817 3002,
2818 input_ranged_data,
2819 schema.clone(),
2820 SortOptions {
2821 descending: true,
2822 ..Default::default()
2823 },
2824 Some(4),
2825 expected_output,
2826 Some(9), )
2828 .await;
2829 }
2830
2831 #[tokio::test]
2833 async fn test_early_stop_exact_boundary_equality() {
2834 let unit = TimeUnit::Millisecond;
2835 let schema = Arc::new(Schema::new(vec![Field::new(
2836 "ts",
2837 DataType::Timestamp(unit, None),
2838 false,
2839 )]));
2840
2841 let input_ranged_data = vec![
2845 (
2846 PartitionRange {
2847 start: Timestamp::new(70, unit.into()),
2848 end: Timestamp::new(100, unit.into()),
2849 num_rows: 4,
2850 identifier: 0,
2851 },
2852 vec![
2853 DfRecordBatch::try_new(
2854 schema.clone(),
2855 vec![new_ts_array(unit, vec![92, 91, 90, 89])],
2856 )
2857 .unwrap(),
2858 ],
2859 ),
2860 (
2861 PartitionRange {
2862 start: Timestamp::new(50, unit.into()),
2863 end: Timestamp::new(90, unit.into()),
2864 num_rows: 3,
2865 identifier: 1,
2866 },
2867 vec![
2868 DfRecordBatch::try_new(
2869 schema.clone(),
2870 vec![new_ts_array(unit, vec![88, 87, 86])],
2871 )
2872 .unwrap(),
2873 ],
2874 ),
2875 ];
2876
2877 let expected_output = Some(
2878 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![92, 91, 90])])
2879 .unwrap(),
2880 );
2881
2882 run_test(
2883 3003,
2884 input_ranged_data,
2885 schema.clone(),
2886 SortOptions {
2887 descending: true,
2888 ..Default::default()
2889 },
2890 Some(3),
2891 expected_output,
2892 Some(7), )
2894 .await;
2895
2896 let input_ranged_data = vec![
2900 (
2901 PartitionRange {
2902 start: Timestamp::new(10, unit.into()),
2903 end: Timestamp::new(50, unit.into()),
2904 num_rows: 4,
2905 identifier: 0,
2906 },
2907 vec![
2908 DfRecordBatch::try_new(
2909 schema.clone(),
2910 vec![new_ts_array(unit, vec![10, 15, 20, 25])],
2911 )
2912 .unwrap(),
2913 ],
2914 ),
2915 (
2916 PartitionRange {
2917 start: Timestamp::new(20, unit.into()),
2918 end: Timestamp::new(60, unit.into()),
2919 num_rows: 3,
2920 identifier: 1,
2921 },
2922 vec![
2923 DfRecordBatch::try_new(
2924 schema.clone(),
2925 vec![new_ts_array(unit, vec![21, 22, 23])],
2926 )
2927 .unwrap(),
2928 ],
2929 ),
2930 ];
2931
2932 let expected_output = Some(
2933 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![10, 15, 20])])
2934 .unwrap(),
2935 );
2936
2937 run_test(
2938 3004,
2939 input_ranged_data,
2940 schema.clone(),
2941 SortOptions {
2942 descending: false,
2943 ..Default::default()
2944 },
2945 Some(3),
2946 expected_output,
2947 Some(7), )
2949 .await;
2950 }
2951
2952 #[tokio::test]
2954 async fn test_early_stop_with_empty_partitions() {
2955 let unit = TimeUnit::Millisecond;
2956 let schema = Arc::new(Schema::new(vec![Field::new(
2957 "ts",
2958 DataType::Timestamp(unit, None),
2959 false,
2960 )]));
2961
2962 let input_ranged_data = vec![
2964 (
2965 PartitionRange {
2966 start: Timestamp::new(70, unit.into()),
2967 end: Timestamp::new(100, unit.into()),
2968 num_rows: 0,
2969 identifier: 0,
2970 },
2971 vec![
2972 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2974 .unwrap(),
2975 ],
2976 ),
2977 (
2978 PartitionRange {
2979 start: Timestamp::new(50, unit.into()),
2980 end: Timestamp::new(100, unit.into()),
2981 num_rows: 0,
2982 identifier: 1,
2983 },
2984 vec![
2985 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
2987 .unwrap(),
2988 ],
2989 ),
2990 (
2991 PartitionRange {
2992 start: Timestamp::new(30, unit.into()),
2993 end: Timestamp::new(80, unit.into()),
2994 num_rows: 4,
2995 identifier: 2,
2996 },
2997 vec![
2998 DfRecordBatch::try_new(
2999 schema.clone(),
3000 vec![new_ts_array(unit, vec![74, 75, 76, 77])],
3001 )
3002 .unwrap(),
3003 ],
3004 ),
3005 (
3006 PartitionRange {
3007 start: Timestamp::new(10, unit.into()),
3008 end: Timestamp::new(60, unit.into()),
3009 num_rows: 3,
3010 identifier: 3,
3011 },
3012 vec![
3013 DfRecordBatch::try_new(
3014 schema.clone(),
3015 vec![new_ts_array(unit, vec![58, 59, 60])],
3016 )
3017 .unwrap(),
3018 ],
3019 ),
3020 ];
3021
3022 let expected_output = Some(
3025 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![77, 76])]).unwrap(),
3026 );
3027
3028 run_test(
3029 3005,
3030 input_ranged_data,
3031 schema.clone(),
3032 SortOptions {
3033 descending: true,
3034 ..Default::default()
3035 },
3036 Some(2),
3037 expected_output,
3038 Some(7), )
3040 .await;
3041
3042 let input_ranged_data = vec![
3044 (
3045 PartitionRange {
3046 start: Timestamp::new(70, unit.into()),
3047 end: Timestamp::new(100, unit.into()),
3048 num_rows: 4,
3049 identifier: 0,
3050 },
3051 vec![
3052 DfRecordBatch::try_new(
3053 schema.clone(),
3054 vec![new_ts_array(unit, vec![96, 97, 98, 99])],
3055 )
3056 .unwrap(),
3057 ],
3058 ),
3059 (
3060 PartitionRange {
3061 start: Timestamp::new(50, unit.into()),
3062 end: Timestamp::new(90, unit.into()),
3063 num_rows: 0,
3064 identifier: 1,
3065 },
3066 vec![
3067 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
3069 .unwrap(),
3070 ],
3071 ),
3072 (
3073 PartitionRange {
3074 start: Timestamp::new(30, unit.into()),
3075 end: Timestamp::new(70, unit.into()),
3076 num_rows: 0,
3077 identifier: 2,
3078 },
3079 vec![
3080 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])])
3082 .unwrap(),
3083 ],
3084 ),
3085 (
3086 PartitionRange {
3087 start: Timestamp::new(10, unit.into()),
3088 end: Timestamp::new(50, unit.into()),
3089 num_rows: 3,
3090 identifier: 3,
3091 },
3092 vec![
3093 DfRecordBatch::try_new(
3094 schema.clone(),
3095 vec![new_ts_array(unit, vec![48, 49, 50])],
3096 )
3097 .unwrap(),
3098 ],
3099 ),
3100 ];
3101
3102 let expected_output = Some(
3105 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![99, 98])]).unwrap(),
3106 );
3107
3108 run_test(
3109 3006,
3110 input_ranged_data,
3111 schema.clone(),
3112 SortOptions {
3113 descending: true,
3114 ..Default::default()
3115 },
3116 Some(2),
3117 expected_output,
3118 Some(7), )
3120 .await;
3121 }
3122}