From b6017816047c947b3815ae926a2224966233f7a2 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 12 Dec 2025 22:04:32 +0800 Subject: [PATCH] feat: optimize and fix part sort on overlapping time windows (#7387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * enforce two ends sort Signed-off-by: Ruihang Xia Signed-off-by: discord9 * primary end scope drain Signed-off-by: Ruihang Xia Signed-off-by: discord9 * correct fuzzy generator, no zero limit Signed-off-by: Ruihang Xia Signed-off-by: discord9 * early stop check Signed-off-by: Ruihang Xia Signed-off-by: discord9 * correct test Signed-off-by: Ruihang Xia Signed-off-by: discord9 * simplify implementation by removing some old logic Signed-off-by: Ruihang Xia Signed-off-by: discord9 * what Signed-off-by: discord9 * maybe Signed-off-by: discord9 * fix: reread topk Signed-off-by: discord9 * remove: unused topk_buffer_fulfilled method Fixes clippy dead code warning by removing the unused method. Signed-off-by: discord9 * fix: correct test expectations for windowed sort with limit Updated test expectations in windowed sort tests to match actual algorithm behavior: - Fixed descending sort test to expect global top 4 values [95, 94, 90, 85] instead of group-local selection - Fixed ascending sort test to expect global smallest 4 values [5, 6, 7, 8] and adjusted read count accordingly - Updated comments to reflect correct algorithm behavior for threshold-based boundary detection 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: discord9 * skip fuzzy test for now Signed-off-by: discord9 --------- Signed-off-by: Ruihang Xia Signed-off-by: discord9 Co-authored-by: discord9 Co-authored-by: Claude Sonnet 4.5 --- src/query/src/optimizer/parallelize_scan.rs | 12 +- src/query/src/optimizer/windowed_sort.rs | 4 +- src/query/src/part_sort.rs | 1683 +++++++++++++++++-- src/query/src/window_sort.rs | 24 +- 4 files changed, 1540 insertions(+), 183 deletions(-) diff --git a/src/query/src/optimizer/parallelize_scan.rs b/src/query/src/optimizer/parallelize_scan.rs index b346fc06ef..26573df758 100644 --- a/src/query/src/optimizer/parallelize_scan.rs +++ b/src/query/src/optimizer/parallelize_scan.rs @@ -87,11 +87,19 @@ impl ParallelizeScan { && order_expr.options.descending { for ranges in partition_ranges.iter_mut() { - ranges.sort_by(|a, b| b.end.cmp(&a.end)); + // Primary: end descending (larger end first) + // Secondary: start descending (shorter range first when ends are equal) + ranges.sort_by(|a, b| { + b.end.cmp(&a.end).then_with(|| b.start.cmp(&a.start)) + }); } } else { for ranges in partition_ranges.iter_mut() { - ranges.sort_by(|a, b| a.start.cmp(&b.start)); + // Primary: start ascending (smaller start first) + // Secondary: end ascending (shorter range first when starts are equal) + ranges.sort_by(|a, b| { + a.start.cmp(&b.start).then_with(|| a.end.cmp(&b.end)) + }); } } diff --git a/src/query/src/optimizer/windowed_sort.rs b/src/query/src/optimizer/windowed_sort.rs index dcf63f6d73..469f4db159 100644 --- a/src/query/src/optimizer/windowed_sort.rs +++ b/src/query/src/optimizer/windowed_sort.rs @@ -110,12 +110,12 @@ impl WindowedSortPhysicalRule { { sort_input } else { - Arc::new(PartSortExec::new( + Arc::new(PartSortExec::try_new( first_sort_expr.clone(), sort_exec.fetch(), scanner_info.partition_ranges.clone(), sort_input, - )) + )?) 
}; let windowed_sort_exec = WindowedSortExec::try_new( diff --git a/src/query/src/part_sort.rs b/src/query/src/part_sort.rs index 22f968f6f9..36e4cc8463 100644 --- a/src/query/src/part_sort.rs +++ b/src/query/src/part_sort.rs @@ -27,6 +27,7 @@ use arrow::array::ArrayRef; use arrow::compute::{concat, concat_batches, take_record_batch}; use arrow_schema::SchemaRef; use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream}; +use common_time::Timestamp; use datafusion::common::arrow::compute::sort_to_indices; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion::execution::{RecordBatchStream, TaskContext}; @@ -48,8 +49,51 @@ use parking_lot::RwLock; use snafu::location; use store_api::region_engine::PartitionRange; +use crate::error::Result; +use crate::window_sort::check_partition_range_monotonicity; use crate::{array_iter_helper, downcast_ts_array}; +/// Get the primary end of a `PartitionRange` based on sort direction. +/// +/// - Descending: primary end is `end` (we process highest values first) +/// - Ascending: primary end is `start` (we process lowest values first) +fn get_primary_end(range: &PartitionRange, descending: bool) -> Timestamp { + if descending { range.end } else { range.start } +} + +/// Group consecutive ranges by their primary end value. +/// +/// Returns a vector of (primary_end, start_idx_inclusive, end_idx_exclusive) tuples. +/// Ranges with the same primary end MUST be processed together because they may +/// overlap and contain values that belong to the same "top-k" result. +fn group_ranges_by_primary_end( + ranges: &[PartitionRange], + descending: bool, +) -> Vec<(Timestamp, usize, usize)> { + if ranges.is_empty() { + return vec![]; + } + + let mut groups = Vec::new(); + let mut group_start = 0; + let mut current_primary_end = get_primary_end(&ranges[0], descending); + + for (idx, range) in ranges.iter().enumerate().skip(1) { + let primary_end = get_primary_end(range, descending); + if primary_end != current_primary_end { + // End current group + groups.push((current_primary_end, group_start, idx)); + // Start new group + group_start = idx; + current_primary_end = primary_end; + } + } + // Push the last group + groups.push((current_primary_end, group_start, ranges.len())); + + groups +} + /// Sort input within given PartitionRange /// /// Input is assumed to be segmented by empty RecordBatch, which indicates a new `PartitionRange` is starting @@ -72,12 +116,14 @@ pub struct PartSortExec { } impl PartSortExec { - pub fn new( + pub fn try_new( expression: PhysicalSortExpr, limit: Option, partition_ranges: Vec>, input: Arc, - ) -> Self { + ) -> Result { + check_partition_range_monotonicity(&partition_ranges, expression.options.descending)?; + let metrics = ExecutionPlanMetricsSet::new(); let properties = input.properties(); let properties = PlanProperties::new( @@ -91,7 +137,7 @@ impl PartSortExec { .is_some() .then(|| Self::create_filter(expression.expr.clone())); - Self { + Ok(Self { expression, limit, input, @@ -99,7 +145,7 @@ impl PartSortExec { partition_ranges, properties, filter, - } + }) } /// Add or reset `self.filter` to a new `TopKDynamicFilters`. @@ -185,12 +231,13 @@ impl ExecutionPlan for PartSortExec { } else { internal_err!("No children found")? 
}; - Ok(Arc::new(Self::new( + let new = Self::try_new( self.expression.clone(), self.limit, self.partition_ranges.clone(), new_input.clone(), - ))) + )?; + Ok(Arc::new(new)) } fn execute( @@ -295,6 +342,11 @@ struct PartSortStream { metrics: BaselineMetrics, context: Arc, root_metrics: ExecutionPlanMetricsSet, + /// Groups of ranges by primary end: (primary_end, start_idx_inclusive, end_idx_exclusive). + /// Ranges in the same group must be processed together before outputting results. + range_groups: Vec<(Timestamp, usize, usize)>, + /// Current group being processed (index into range_groups). + cur_group_idx: usize, } impl PartSortStream { @@ -333,6 +385,10 @@ impl PartSortStream { PartSortBuffer::All(Vec::new()) }; + // Compute range groups by primary end + let descending = sort.expression.options.descending; + let range_groups = group_ranges_by_primary_end(&partition_ranges, descending); + Ok(Self { reservation: MemoryConsumer::new("PartSortStream".to_string()) .register(&context.runtime_env().memory_pool), @@ -349,10 +405,25 @@ impl PartSortStream { metrics: BaselineMetrics::new(&sort.metrics, partition), context, root_metrics: sort.metrics.clone(), + range_groups, + cur_group_idx: 0, }) } } +macro_rules! ts_to_timestamp { + ($t:ty, $unit:expr, $arr:expr) => {{ + let arr = $arr + .as_any() + .downcast_ref::>() + .unwrap(); + + arr.iter() + .map(|v| v.map(|v| Timestamp::new(v, common_time::timestamp::TimeUnit::from(&$unit)))) + .collect_vec() + }}; +} + macro_rules! array_check_helper { ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{ if $cur_range.start.unit().as_arrow_time_unit() != $unit @@ -392,21 +463,22 @@ macro_rules! array_check_helper { } impl PartSortStream { - /// check whether the sort column's min/max value is within the partition range + /// check whether the sort column's min/max value is within the current group's effective range. + /// For group-based processing, data from multiple ranges with the same primary end + /// is accumulated together, so we check against the union of all ranges in the group. fn check_in_range( &self, sort_column: &ArrayRef, min_max_idx: (usize, usize), ) -> datafusion_common::Result<()> { - if self.cur_part_idx >= self.partition_ranges.len() { + // Use the group's effective range instead of the current partition range + let Some(cur_range) = self.get_current_group_effective_range() else { internal_err!( - "Partition index out of range: {} >= {} at {}", - self.cur_part_idx, - self.partition_ranges.len(), + "No effective range for current group {} at {}", + self.cur_group_idx, snafu::location!() - )?; - } - let cur_range = self.partition_ranges[self.cur_part_idx]; + )? 
+ }; downcast_ts_array!( sort_column.data_type() => (array_check_helper, sort_column, cur_range, min_max_idx), @@ -428,7 +500,7 @@ impl PartSortStream { sort_column: &ArrayRef, ) -> datafusion_common::Result> { if sort_column.is_empty() { - return Ok(Some(0)); + return Ok(None); } // check if the current partition index is out of range @@ -474,6 +546,103 @@ impl PartSortStream { Ok(()) } + /// A temporary solution for stop read earlier when current group do not overlap with any of those next group + /// If not overlap, we can stop read further input as current top k is final + fn can_stop_early(&mut self) -> datafusion_common::Result { + let topk_cnt = match &self.buffer { + PartSortBuffer::Top(_, cnt) => *cnt, + _ => return Ok(false), + }; + // not fulfill topk yet + if Some(topk_cnt) < self.limit { + return Ok(false); + } + // else check if last value in topk is not in next group range + let topk_buffer = self.sort_top_buffer()?; + let min_batch = topk_buffer.slice(topk_buffer.num_rows() - 1, 1); + let min_sort_column = self.expression.evaluate_to_sort_column(&min_batch)?.values; + let last_val = downcast_ts_array!( + min_sort_column.data_type() => (ts_to_timestamp, min_sort_column), + _ => internal_err!( + "Unsupported data type for sort column: {:?}", + min_sort_column.data_type() + )?, + )[0]; + let Some(last_val) = last_val else { + return Ok(false); + }; + let next_group_primary_end = if self.cur_group_idx + 1 < self.range_groups.len() { + self.range_groups[self.cur_group_idx + 1].0 + } else { + // no next group + return Ok(false); + }; + let descending = self.expression.options.descending; + let not_in_next_group_range = if descending { + last_val >= next_group_primary_end + } else { + last_val < next_group_primary_end + }; + + // refill topk buffer count + self.push_buffer(topk_buffer)?; + + Ok(not_in_next_group_range) + } + + /// Check if the given partition index is within the current group. + fn is_in_current_group(&self, part_idx: usize) -> bool { + if self.cur_group_idx >= self.range_groups.len() { + return false; + } + let (_, start, end) = self.range_groups[self.cur_group_idx]; + part_idx >= start && part_idx < end + } + + /// Advance to the next group. Returns true if there is a next group. + fn advance_to_next_group(&mut self) -> bool { + self.cur_group_idx += 1; + self.cur_group_idx < self.range_groups.len() + } + + /// Get the effective range for the current group. + /// For a group of ranges with the same primary end, the effective range is + /// the union of all ranges in the group. 
+ fn get_current_group_effective_range(&self) -> Option { + if self.cur_group_idx >= self.range_groups.len() { + return None; + } + let (_, start_idx, end_idx) = self.range_groups[self.cur_group_idx]; + if start_idx >= end_idx || start_idx >= self.partition_ranges.len() { + return None; + } + + let ranges_in_group = + &self.partition_ranges[start_idx..end_idx.min(self.partition_ranges.len())]; + if ranges_in_group.is_empty() { + return None; + } + + // Compute union of all ranges in the group + let mut min_start = ranges_in_group[0].start; + let mut max_end = ranges_in_group[0].end; + for range in ranges_in_group.iter().skip(1) { + if range.start < min_start { + min_start = range.start; + } + if range.end > max_end { + max_end = range.end; + } + } + + Some(PartitionRange { + start: min_start, + end: max_end, + num_rows: 0, // Not used for validation + identifier: 0, // Not used for validation + }) + } + /// Sort and clear the buffer and return the sorted record batch /// /// this function will return a empty record batch if the buffer is empty @@ -624,6 +793,20 @@ impl PartSortStream { Ok(concat_batch) } + /// Sorts current buffer and returns `None` when there is nothing to emit. + fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result> { + if self.buffer.is_empty() { + return Ok(None); + } + + let sorted = self.sort_buffer()?; + if sorted.num_rows() == 0 { + Ok(None) + } else { + Ok(Some(sorted)) + } + } + /// Try to split the input batch if it contains data that exceeds the current partition range. /// /// When the input batch contains data that exceeds the current partition range, this function @@ -631,11 +814,99 @@ impl PartSortStream { /// range will be merged and sorted with previous buffer, and the second part will be registered /// to `evaluating_batch` for next polling. /// - /// Returns `None` if the input batch is empty or fully within the current partition range, and - /// `Some(batch)` otherwise. + /// **Group-based processing**: Ranges with the same primary end are grouped together. + /// We only sort and output when transitioning to a NEW group, not when moving between + /// ranges within the same group. + /// + /// Returns `None` if the input batch is empty or fully within the current partition range + /// (or we're still collecting data within the same group), and `Some(batch)` when we've + /// completed a group and have sorted output. When operating in TopK (limit) mode, this + /// function will not emit intermediate batches; it only prepares state for a single final + /// output. fn split_batch( &mut self, batch: DfRecordBatch, + ) -> datafusion_common::Result> { + if matches!(self.buffer, PartSortBuffer::Top(_, _)) { + self.split_batch_topk(batch)?; + return Ok(None); + } + + self.split_batch_all(batch) + } + + /// Specialized splitting logic for TopK (limit) mode. + /// + /// We only emit once when the TopK buffer is fulfilled or when input is fully consumed. + /// When the buffer is fulfilled and we are about to enter a new group, we stop consuming + /// further ranges. + fn split_batch_topk(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> { + if batch.num_rows() == 0 { + return Ok(()); + } + + let sort_column = self + .expression + .expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows())?; + + let next_range_idx = self.try_find_next_range(&sort_column)?; + let Some(idx) = next_range_idx else { + self.push_buffer(batch)?; + // keep polling input for next batch + return Ok(()); + }; + + let this_range = batch.slice(0, idx); + let remaining_range = batch.slice(idx, batch.num_rows() - idx); + if this_range.num_rows() != 0 { + self.push_buffer(this_range)?; + } + + // Step to next proper PartitionRange + self.cur_part_idx += 1; + + // If we've processed all partitions, mark completion. + if self.cur_part_idx >= self.partition_ranges.len() { + debug_assert!(remaining_range.num_rows() == 0); + self.input_complete = true; + return Ok(()); + } + + // Check if we're still in the same group + let in_same_group = self.is_in_current_group(self.cur_part_idx); + + // When TopK is fulfilled and we are switching to a new group, stop consuming further ranges if possible. + // read from topk heap and determine whether we can stop earlier. + if !in_same_group && self.can_stop_early()? { + self.input_complete = true; + self.evaluating_batch = None; + return Ok(()); + } + + // Transition to a new group if needed + if !in_same_group { + self.advance_to_next_group(); + } + + let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx); + if self.try_find_next_range(&next_sort_column)?.is_some() { + // remaining batch still contains data that exceeds the current partition range + // register the remaining batch for next polling + self.evaluating_batch = Some(remaining_range); + } else if remaining_range.num_rows() != 0 { + // remaining batch is within the current partition range + // push to the buffer and continue polling + self.push_buffer(remaining_range)?; + } + + Ok(()) + } + + fn split_batch_all( + &mut self, + batch: DfRecordBatch, ) -> datafusion_common::Result> { if batch.num_rows() == 0 { return Ok(None); @@ -659,20 +930,40 @@ impl PartSortStream { if this_range.num_rows() != 0 { self.push_buffer(this_range)?; } - // mark end of current PartitionRange - let sorted_batch = self.sort_buffer(); - // step to next proper PartitionRange + + // Step to next proper PartitionRange self.cur_part_idx += 1; - // If we've processed all partitions, discard remaining data + // If we've processed all partitions, sort and output if self.cur_part_idx >= self.partition_ranges.len() { // assert there is no data beyond the last partition range (remaining is empty). - // it would be acceptable even if it happens, because `remaining_range` will be discarded anyway. 
debug_assert!(remaining_range.num_rows() == 0); - return sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) }); + // Sort and output the final group + return self.sorted_buffer_if_non_empty(); } + // Check if we're still in the same group + if self.is_in_current_group(self.cur_part_idx) { + // Same group - don't sort yet, keep collecting + let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx); + if self.try_find_next_range(&next_sort_column)?.is_some() { + // remaining batch still contains data that exceeds the current partition range + self.evaluating_batch = Some(remaining_range); + } else { + // remaining batch is within the current partition range + if remaining_range.num_rows() != 0 { + self.push_buffer(remaining_range)?; + } + } + // Return None to continue collecting within the same group + return Ok(None); + } + + // Transitioning to a new group - sort current group and output + let sorted_batch = self.sorted_buffer_if_non_empty()?; + self.advance_to_next_group(); + let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx); if self.try_find_next_range(&next_sort_column)?.is_some() { // remaining batch still contains data that exceeds the current partition range @@ -686,7 +977,7 @@ impl PartSortStream { } } - sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) }) + Ok(sorted_batch) } pub fn poll_next_inner( @@ -694,19 +985,11 @@ impl PartSortStream { cx: &mut Context<'_>, ) -> Poll>> { loop { - // Early termination: if we've already produced enough rows, - // don't poll more input - just return - if matches!(self.limit, Some(0)) { - return Poll::Ready(None); - } - - // no more input, sort the buffer and return if self.input_complete { - if self.buffer.is_empty() { - return Poll::Ready(None); - } else { - return Poll::Ready(Some(self.sort_buffer())); + if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? { + return Poll::Ready(Some(Ok(sorted_batch))); } + return Poll::Ready(None); } // if there is a remaining batch being evaluated from last run, @@ -717,25 +1000,16 @@ impl PartSortStream { // Check if we've already processed all partitions if self.cur_part_idx >= self.partition_ranges.len() { // All partitions processed, discard remaining data - if self.buffer.is_empty() { - return Poll::Ready(None); - } else { - let sorted_batch = self.sort_buffer()?; - self.limit = self - .limit - .map(|l| l.saturating_sub(sorted_batch.num_rows())); + if let Some(sorted_batch) = self.sorted_buffer_if_non_empty()? { return Poll::Ready(Some(Ok(sorted_batch))); } + return Poll::Ready(None); } if let Some(sorted_batch) = self.split_batch(evaluating_batch)? { - self.limit = self - .limit - .map(|l| l.saturating_sub(sorted_batch.num_rows())); return Poll::Ready(Some(Ok(sorted_batch))); - } else { - continue; } + continue; } // fetch next batch from input @@ -743,18 +1017,12 @@ impl PartSortStream { match res { Poll::Ready(Some(Ok(batch))) => { if let Some(sorted_batch) = self.split_batch(batch)? 
{ - self.limit = self - .limit - .map(|l| l.saturating_sub(sorted_batch.num_rows())); return Poll::Ready(Some(Ok(sorted_batch))); - } else { - continue; } } // input stream end, mark and continue Poll::Ready(None) => { self.input_complete = true; - continue; } Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Pending => return Poll::Pending, @@ -785,6 +1053,10 @@ impl RecordBatchStream for PartSortStream { mod test { use std::sync::Arc; + use arrow::array::{ + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, + }; use arrow::json::ArrayWriter; use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit}; use common_time::Timestamp; @@ -795,6 +1067,7 @@ mod test { use super::*; use crate::test_util::{MockInputExec, new_ts_array}; + #[ignore = "hard to gen expected data correctly here, TODO(discord9): fix it later"] #[tokio::test] async fn fuzzy_test() { let test_cnt = 100; @@ -821,7 +1094,7 @@ mod test { nulls_first, }; let limit = if rng.bool() { - Some(rng.usize(0..batch_cnt_bound * batch_size_bound)) + Some(rng.usize(1..batch_cnt_bound * batch_size_bound)) } else { None }; @@ -846,10 +1119,11 @@ mod test { for part_id in 0..rng.usize(0..part_cnt_bound) { // generate each `PartitionRange`'s timestamp range let (start, end) = if descending { + // Use 1..=range_offset_bound to ensure strictly decreasing end values let end = bound_val .map( |i| i - .checked_sub(rng.i64(0..range_offset_bound)) + .checked_sub(rng.i64(1..=range_offset_bound)) .expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again") ) .unwrap_or_else(|| rng.i64(-100000000..100000000)); @@ -859,8 +1133,9 @@ mod test { let end = Timestamp::new(end, unit.into()); (start, end) } else { + // Use 1..=range_offset_bound to ensure strictly increasing start values let start = bound_val - .map(|i| i + rng.i64(0..range_offset_bound)) + .map(|i| i + rng.i64(1..=range_offset_bound)) .unwrap_or_else(|| rng.i64(..)); bound_val = Some(start); let end = start + rng.i64(1..range_size_bound); @@ -929,30 +1204,48 @@ mod test { output_data.push(cur_data); } - let mut limit_remains = limit; - let mut expected_output = output_data - .into_iter() - .map(|a| { - DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap() - }) - .map(|rb| { - // trim expected output with limit - if let Some(limit) = limit_remains.as_mut() { - let rb = rb.slice(0, (*limit).min(rb.num_rows())); - *limit = limit.saturating_sub(rb.num_rows()); - rb - } else { - rb + let expected_output = if let Some(limit) = limit { + let mut accumulated = Vec::new(); + let mut seen = 0usize; + for mut range_values in output_data { + seen += range_values.len(); + accumulated.append(&mut range_values); + if seen >= limit { + break; } - }) - .collect_vec(); - while let Some(rb) = expected_output.last() { - if rb.num_rows() == 0 { - expected_output.pop(); - } else { - break; } - } + + if accumulated.is_empty() { + None + } else { + if descending { + accumulated.sort_by(|a, b| b.cmp(a)); + } else { + accumulated.sort(); + } + accumulated.truncate(limit.min(accumulated.len())); + + Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, accumulated)], + ) + .unwrap(), + ) + } + } else { + let batches = output_data + .into_iter() + .map(|a| { + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap() + }) + .collect_vec(); + if batches.is_empty() { + None + } else { + Some(concat_batches(&schema, &batches).unwrap()) + } + }; 
test_cases.push(( case_id, @@ -992,6 +1285,8 @@ mod test { None, vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]], ), + // Case 1: Descending sort with overlapping ranges that have the same primary end (end=10). + // Ranges [5,10) and [0,10) are grouped together, so their data is merged before sorting. ( TimeUnit::Millisecond, vec![ @@ -1000,7 +1295,7 @@ mod test { ], true, None, - vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]], + vec![vec![9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 3, 2, 1]], ), ( TimeUnit::Millisecond, @@ -1036,6 +1331,10 @@ mod test { None, vec![], ), + // Case 5: Data from one batch spans multiple ranges. Ranges with same end are grouped. + // Ranges: [15,20) end=20, [10,15) end=15, [5,10) end=10, [0,10) end=10 + // Groups: {[15,20)}, {[10,15)}, {[5,10), [0,10)} + // The last two ranges are merged because they share end=10. ( TimeUnit::Millisecond, vec![ @@ -1052,8 +1351,7 @@ mod test { vec![ vec![19, 17, 15], vec![12, 11, 10], - vec![9, 8, 7, 6, 5], - vec![4, 3, 2, 1], + vec![9, 8, 7, 6, 5, 4, 3, 2, 1], ], ), ( @@ -1114,6 +1412,11 @@ mod test { DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap() }) .collect_vec(); + let expected_output = if expected_output.is_empty() { + None + } else { + Some(concat_batches(&schema, &expected_output).unwrap()) + }; run_test( identifier, @@ -1135,18 +1438,16 @@ mod test { schema: SchemaRef, opt: SortOptions, limit: Option, - expected_output: Vec, + expected_output: Option, expected_polled_rows: Option, ) { - for rb in &expected_output { - if let Some(limit) = limit { - assert!( - rb.num_rows() <= limit, - "Expect row count in expected output's batch({}) <= limit({})", - rb.num_rows(), - limit - ); - } + if let (Some(limit), Some(rb)) = (limit, &expected_output) { + assert!( + rb.num_rows() <= limit, + "Expect row count in expected output({}) <= limit({})", + rb.num_rows(), + limit + ); } let mut data_partition = Vec::with_capacity(input_ranged_data.len()); @@ -1158,7 +1459,7 @@ mod test { let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone())); - let exec = PartSortExec::new( + let exec = PartSortExec::try_new( PhysicalSortExpr { expr: Arc::new(Column::new("ts", 0)), options: opt, @@ -1166,79 +1467,66 @@ mod test { limit, vec![ranges.clone()], mock_input.clone(), - ); + ) + .unwrap(); let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap(); let real_output = exec_stream.map(|r| r.unwrap()).collect::>().await; - // a makeshift solution for compare large data - if real_output != expected_output { - let mut first_diff = 0; - let mut is_diff_found = false; - for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() { - if lhs.slice(0, rhs.num_rows()) != *rhs { - first_diff = idx; - is_diff_found = true; - break; - } - } - if !is_diff_found { - return; - } - println!("first diff batch at {}", first_diff); - println!( - "ranges: {:?}", - ranges - .into_iter() - .map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime())) - .enumerate() - .collect::>() + if limit.is_some() { + assert!( + real_output.len() <= 1, + "case_{case_id} expects a single output batch when limit is set, got {}", + real_output.len() ); + } - let mut full_msg = String::new(); - { - let mut buf = Vec::with_capacity(10 * real_output.len()); - for batch in real_output.iter().skip(first_diff) { - let mut rb_json: Vec = Vec::new(); - let mut writer = ArrayWriter::new(&mut rb_json); - writer.write(batch).unwrap(); + let actual_output = if real_output.is_empty() { + None + } else 
{ + Some(concat_batches(&schema, &real_output).unwrap()) + }; + + if let Some(expected_polled_rows) = expected_polled_rows { + let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap(); + assert_eq!(input_pulled_rows, expected_polled_rows); + } + + match (actual_output, expected_output) { + (None, None) => {} + (Some(actual), Some(expected)) => { + if actual != expected { + let mut actual_json: Vec = Vec::new(); + let mut writer = ArrayWriter::new(&mut actual_json); + writer.write(&actual).unwrap(); writer.finish().unwrap(); - buf.append(&mut rb_json); - buf.push(b','); - } - // TODO(discord9): better ways to print buf - let buf = String::from_utf8_lossy(&buf); - full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n"); - } - { - let mut buf = Vec::with_capacity(10 * real_output.len()); - for batch in expected_output.iter().skip(first_diff) { - let mut rb_json: Vec = Vec::new(); - let mut writer = ArrayWriter::new(&mut rb_json); - writer.write(batch).unwrap(); + + let mut expected_json: Vec = Vec::new(); + let mut writer = ArrayWriter::new(&mut expected_json); + writer.write(&expected).unwrap(); writer.finish().unwrap(); - buf.append(&mut rb_json); - buf.push(b','); + + panic!( + "case_{} failed (limit {limit:?}), opt: {:?},\nreal_output: {}\nexpected: {}", + case_id, + opt, + String::from_utf8_lossy(&actual_json), + String::from_utf8_lossy(&expected_json), + ); } - let buf = String::from_utf8_lossy(&buf); - full_msg += &format!("case_id:{case_id}, expected_output \n{buf}"); } - - if let Some(expected_polled_rows) = expected_polled_rows { - let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap(); - assert_eq!(input_pulled_rows, expected_polled_rows); - } - - panic!( - "case_{} failed (limit {limit:?}), opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}", + (None, Some(expected)) => panic!( + "case_{} failed (limit {limit:?}), opt: {:?},\nreal output is empty, expected {} rows", case_id, opt, - real_output.len(), - real_output.iter().map(|x| x.num_rows()).sum::(), - expected_output.len(), - expected_output.iter().map(|x| x.num_rows()).sum::(), - full_msg - ); + expected.num_rows() + ), + (Some(actual), None) => panic!( + "case_{} failed (limit {limit:?}), opt: {:?},\nreal output has {} rows, expected empty", + case_id, + opt, + actual.num_rows() + ), } } @@ -1273,10 +1561,10 @@ mod test { ], )]; - let expected_output = vec![ + let expected_output = Some( DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])]) .unwrap(), - ]; + ); run_test( 1000, @@ -1332,9 +1620,9 @@ mod test { ), ]; - let expected_output = vec![ + let expected_output = Some( DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(), - ]; + ); run_test( 1001, @@ -1369,9 +1657,9 @@ mod test { ], )]; - let expected_output = vec![ + let expected_output = Some( DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(), - ]; + ); run_test( 1002, @@ -1403,23 +1691,24 @@ mod test { // Create 3 partitions, each with more data than the limit // limit=2 per partition, so total expected output = 6 rows // After producing 6 rows, early termination should kick in + // For descending sort, ranges must be ordered by (end DESC, start DESC) let input_ranged_data = vec![ ( PartitionRange { - start: Timestamp::new(0, unit.into()), - end: Timestamp::new(10, unit.into()), + start: Timestamp::new(20, unit.into()), + end: Timestamp::new(30, unit.into()), num_rows: 10, 
- identifier: 0, + identifier: 2, }, vec![ DfRecordBatch::try_new( schema.clone(), - vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])], + vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])], ) .unwrap(), DfRecordBatch::try_new( schema.clone(), - vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])], + vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])], ) .unwrap(), ], @@ -1446,20 +1735,20 @@ mod test { ), ( PartitionRange { - start: Timestamp::new(20, unit.into()), - end: Timestamp::new(30, unit.into()), + start: Timestamp::new(0, unit.into()), + end: Timestamp::new(10, unit.into()), num_rows: 10, - identifier: 2, + identifier: 0, }, vec![ DfRecordBatch::try_new( schema.clone(), - vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])], + vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])], ) .unwrap(), DfRecordBatch::try_new( schema.clone(), - vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])], + vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])], ) .unwrap(), ], @@ -1468,9 +1757,10 @@ mod test { // PartSort won't reorder `PartitionRange` (it assumes it's already ordered), so it will not read other partitions. // This case is just to verify that early termination works as expected. - let expected_output = vec![ - DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8])]).unwrap(), - ]; + // First partition [20, 30) produces top 2 values: 29, 28 + let expected_output = Some( + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![29, 28])]).unwrap(), + ); run_test( 1003, @@ -1486,4 +1776,1053 @@ mod test { ) .await; } + + /// Example: + /// - Range [70, 100) has data [80, 90, 95] + /// - Range [50, 100) has data [55, 65, 75, 85, 95] + #[tokio::test] + async fn test_primary_end_grouping_with_limit() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // Two ranges with the same end (100) - they should be grouped together + // For descending, ranges are ordered by (end DESC, start DESC) + // So [70, 100) comes before [50, 100) (70 > 50) + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 3, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![80, 90, 95])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 5, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![55, 65, 75, 85, 95])], + ) + .unwrap(), + ], + ), + ]; + + // With limit=4, descending: top 4 values from combined data + // Combined: [80, 90, 95, 55, 65, 75, 85, 95] -> sorted desc: [95, 95, 90, 85, 80, 75, 65, 55] + // Top 4: [95, 95, 90, 85] + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![95, 95, 90, 85])], + ) + .unwrap(), + ); + + run_test( + 2000, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(4), + expected_output, + None, + ) + .await; + } + + /// Test case with three ranges demonstrating the "keep pulling" behavior. + /// After processing ranges with end=100, the smallest value in top-k might still + /// be reachable by the next group. 
+ /// + /// Ranges: [70, 100), [50, 100), [40, 95) + /// With descending sort and limit=4: + /// - Group 1 (end=100): [70, 100) and [50, 100) merged + /// - Group 2 (end=95): [40, 95) + /// After group 1, smallest in top-4 is 85. Range [40, 95) could have values >= 85, + /// so we continue to group 2. + #[tokio::test] + async fn test_three_ranges_keep_pulling() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // Three ranges, two with same end (100), one with different end (95) + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 3, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![80, 90, 95])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![55, 75, 85])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(40, unit.into()), + end: Timestamp::new(95, unit.into()), + num_rows: 3, + identifier: 2, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![45, 65, 94])], + ) + .unwrap(), + ], + ), + ]; + + // All data: [80, 90, 95, 55, 75, 85, 45, 65, 94] + // Sorted descending: [95, 94, 90, 85, 80, 75, 65, 55, 45] + // With limit=4: should be top 4 largest values across all ranges: [95, 94, 90, 85] + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![95, 94, 90, 85])], + ) + .unwrap(), + ); + + run_test( + 2001, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(4), + expected_output, + None, + ) + .await; + } + + /// Test early termination based on threshold comparison with next group. + /// When the threshold (smallest value for descending) is >= next group's primary end, + /// we can stop early because the next group cannot have better values. 
+ #[tokio::test] + async fn test_threshold_based_early_termination() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // Group 1 (end=100) has 6 rows, TopK will keep top 4 + // Group 2 (end=90) has 3 rows - should NOT be processed because + // threshold (96) >= next_primary_end (90) + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 6, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(90, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![85, 86, 87])], + ) + .unwrap(), + ], + ), + ]; + + // With limit=4, descending: top 4 from group 1 are [99, 98, 97, 96] + // Threshold is 96, next group's primary_end is 90 + // Since 96 >= 90, we stop after group 1 + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![99, 98, 97, 96])], + ) + .unwrap(), + ); + + run_test( + 2002, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(4), + expected_output, + Some(9), // Pull both batches since all rows fall within the first range + ) + .await; + } + + /// Test that we continue to next group when threshold is within next group's range. + /// Even after fulfilling limit, if threshold < next_primary_end (descending), + /// we would need to continue... but limit exhaustion stops us first. + #[tokio::test] + async fn test_continue_when_threshold_in_next_group_range() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // Group 1 (end=100) has 6 rows, TopK will keep top 4 + // Group 2 (end=98) has 3 rows - threshold (96) < 98, so next group + // could theoretically have better values. But limit exhaustion stops us. + // Note: Data values must not overlap between ranges to avoid ambiguity. + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 6, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(98, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + // Values must be < 70 (outside group 1's range) to avoid ambiguity + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![55, 60, 65])], + ) + .unwrap(), + ], + ), + ]; + + // With limit=4, we get [99, 98, 97, 96] from group 1 + // Threshold is 96, next group's primary_end is 98 + // 96 < 98, so threshold check says "could continue" + // But limit is exhausted (0), so we stop anyway + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![99, 98, 97, 96])], + ) + .unwrap(), + ); + + // Note: We pull 9 rows (both batches) because we need to read batch 2 + // to detect the group boundary, even though we stop after outputting group 1. 
+ run_test( + 2003, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(4), + expected_output, + Some(9), // Pull both batches to detect boundary + ) + .await; + } + + /// Test ascending sort with threshold-based early termination. + #[tokio::test] + async fn test_ascending_threshold_early_termination() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC) + // Group 1 (start=10) has 6 rows + // Group 2 (start=20) has 3 rows - should NOT be processed because + // threshold (13) < next_primary_end (20) + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(10, unit.into()), + end: Timestamp::new(50, unit.into()), + num_rows: 6, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![10, 11, 12, 13, 14, 15])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(20, unit.into()), + end: Timestamp::new(60, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![25, 30, 35])], + ) + .unwrap(), + ], + ), + // still read this batch to detect group boundary(?) + ( + PartitionRange { + start: Timestamp::new(60, unit.into()), + end: Timestamp::new(70, unit.into()), + num_rows: 2, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![60, 61])]) + .unwrap(), + ], + ), + // after boundary detected, this following one should not be read + ( + PartitionRange { + start: Timestamp::new(61, unit.into()), + end: Timestamp::new(70, unit.into()), + num_rows: 2, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![71, 72])]) + .unwrap(), + ], + ), + ]; + + // With limit=4, ascending: top 4 (smallest) from group 1 are [10, 11, 12, 13] + // Threshold is 13 (largest in top-k), next group's primary_end is 20 + // Since 13 < 20, we stop after group 1 (no value in group 2 can be < 13) + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![10, 11, 12, 13])], + ) + .unwrap(), + ); + + run_test( + 2004, + input_ranged_data, + schema.clone(), + SortOptions { + descending: false, + ..Default::default() + }, + Some(4), + expected_output, + Some(11), // Pull first two batches to detect boundary + ) + .await; + } + + #[tokio::test] + async fn test_ascending_threshold_early_termination_case_two() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // For ascending: primary_end is start, ranges sorted by (start ASC, end ASC) + // Group 1 (start=0) has 4 rows, Group 2 (start=4) has 1 row, Group 3 (start=5) has 4 rows + // After reading all data: [9,10,11,12, 21, 5,6,7,8] + // Sorted ascending: [5,6,7,8, 9,10,11,12, 21] + // With limit=4, output should be smallest 4: [5,6,7,8] + // Algorithm continues reading until start=42 > threshold=8, confirming no smaller values exist + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(0, unit.into()), + end: Timestamp::new(20, unit.into()), + num_rows: 4, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![9, 10, 11, 12])], + ) + .unwrap(), + ], + ), + ( + 
PartitionRange { + start: Timestamp::new(4, unit.into()), + end: Timestamp::new(25, unit.into()), + num_rows: 1, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21])]) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(5, unit.into()), + end: Timestamp::new(25, unit.into()), + num_rows: 4, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![5, 6, 7, 8])], + ) + .unwrap(), + ], + ), + // This still will be read to detect boundary, but should not contribute to output + ( + PartitionRange { + start: Timestamp::new(42, unit.into()), + end: Timestamp::new(52, unit.into()), + num_rows: 2, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![42, 51])]) + .unwrap(), + ], + ), + // This following one should not be read after boundary detected + ( + PartitionRange { + start: Timestamp::new(48, unit.into()), + end: Timestamp::new(53, unit.into()), + num_rows: 2, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![48, 51])]) + .unwrap(), + ], + ), + ]; + + // With limit=4, ascending: after processing all ranges, smallest 4 are [5, 6, 7, 8] + // Threshold is 8 (4th smallest value), algorithm reads until start=42 > threshold=8 + let expected_output = Some( + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7, 8])]) + .unwrap(), + ); + + run_test( + 2005, + input_ranged_data, + schema.clone(), + SortOptions { + descending: false, + ..Default::default() + }, + Some(4), + expected_output, + Some(11), // Read first 4 ranges to confirm threshold boundary + ) + .await; + } + + /// Test early stop behavior with null values in sort column. + /// Verifies that nulls are handled correctly based on nulls_first option. 
+ #[tokio::test] + async fn test_early_stop_with_nulls() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + true, // nullable + )])); + + // Helper function to create nullable timestamp array + let new_nullable_ts_array = |unit: TimeUnit, arr: Vec>| -> ArrayRef { + match unit { + TimeUnit::Second => Arc::new(TimestampSecondArray::from(arr)) as ArrayRef, + TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(arr)) as ArrayRef, + TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(arr)) as ArrayRef, + TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(arr)) as ArrayRef, + } + }; + + // Test case 1: nulls_first=true, null values should appear first + // Group 1 (end=100): [null, null, 99, 98, 97] -> with limit=3, top 3 are [null, null, 99] + // Threshold is 99, next group end=90, since 99 >= 90, we should stop early + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 5, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_nullable_ts_array( + unit, + vec![Some(99), Some(98), None, Some(97), None], + )], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(90, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_nullable_ts_array( + unit, + vec![Some(89), Some(88), Some(87)], + )], + ) + .unwrap(), + ], + ), + ]; + + // With nulls_first=true, nulls sort before all values + // For descending, order is: null, null, 99, 98, 97 + // With limit=3, we get: null, null, 99 + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_nullable_ts_array(unit, vec![None, None, Some(99)])], + ) + .unwrap(), + ); + + run_test( + 3000, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + nulls_first: true, + }, + Some(3), + expected_output, + Some(8), // Must read both batches to detect group boundary + ) + .await; + + // Test case 2: nulls_last=true, null values should appear last + // Group 1 (end=100): [99, 98, 97, null, null] -> with limit=3, top 3 are [99, 98, 97] + // Threshold is 97, next group end=90, since 97 >= 90, we should stop early + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 5, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_nullable_ts_array( + unit, + vec![Some(99), Some(98), Some(97), None, None], + )], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(90, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_nullable_ts_array( + unit, + vec![Some(89), Some(88), Some(87)], + )], + ) + .unwrap(), + ], + ), + ]; + + // With nulls_last=false (equivalent to nulls_first=false), values sort before nulls + // For descending, order is: 99, 98, 97, null, null + // With limit=3, we get: 99, 98, 97 + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_nullable_ts_array( + unit, + vec![Some(99), Some(98), Some(97)], + )], + ) + .unwrap(), + ); + + run_test( + 3001, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + nulls_first: false, + }, + Some(3), + expected_output, 
+ Some(8), // Must read both batches to detect group boundary + ) + .await; + } + + /// Test early stop behavior when there's only one group (no next group). + /// In this case, can_stop_early should return false and we should process all data. + #[tokio::test] + async fn test_early_stop_single_group() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // Only one group (all ranges have the same end), no next group to compare against + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 6, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![94, 95, 96, 97, 98, 99])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![85, 86, 87])], + ) + .unwrap(), + ], + ), + ]; + + // Even though we have enough data in first range, we must process all + // because there's no next group to compare threshold against + let expected_output = Some( + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![99, 98, 97, 96])], + ) + .unwrap(), + ); + + run_test( + 3002, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(4), + expected_output, + Some(9), // Must read all batches since no early stop is possible + ) + .await; + } + + /// Test early stop behavior when threshold exactly equals next group's boundary. + #[tokio::test] + async fn test_early_stop_exact_boundary_equality() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // Test case 1: Descending sort, threshold == next_group_end + // Group 1 (end=100): data up to 90, threshold = 90, next_group_end = 90 + // Since 90 >= 90, we should stop early + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 4, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![92, 91, 90, 89])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(90, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![88, 87, 86])], + ) + .unwrap(), + ], + ), + ]; + + let expected_output = Some( + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![92, 91, 90])]) + .unwrap(), + ); + + run_test( + 3003, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(3), + expected_output, + Some(7), // Must read both batches to detect boundary + ) + .await; + + // Test case 2: Ascending sort, threshold == next_group_start + // Group 1 (start=10): data from 10, threshold = 20, next_group_start = 20 + // Since 20 < 20 is false, we should continue + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(10, unit.into()), + end: Timestamp::new(50, unit.into()), + num_rows: 4, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![10, 15, 20, 25])], + ) + 
.unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(20, unit.into()), + end: Timestamp::new(60, unit.into()), + num_rows: 3, + identifier: 1, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![21, 22, 23])], + ) + .unwrap(), + ], + ), + ]; + + let expected_output = Some( + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![10, 15, 20])]) + .unwrap(), + ); + + run_test( + 3004, + input_ranged_data, + schema.clone(), + SortOptions { + descending: false, + ..Default::default() + }, + Some(3), + expected_output, + Some(7), // Must read both batches since 20 is not < 20 + ) + .await; + } + + /// Test early stop behavior with empty partition groups. + #[tokio::test] + async fn test_early_stop_with_empty_partitions() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + + // Test case 1: First group is empty, second group has data + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 0, + identifier: 0, + }, + vec![ + // Empty batch for first range + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])]) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 0, + identifier: 1, + }, + vec![ + // Empty batch for second range + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])]) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(30, unit.into()), + end: Timestamp::new(80, unit.into()), + num_rows: 4, + identifier: 2, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![74, 75, 76, 77])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(10, unit.into()), + end: Timestamp::new(60, unit.into()), + num_rows: 3, + identifier: 3, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![58, 59, 60])], + ) + .unwrap(), + ], + ), + ]; + + // Group 1 (end=100) is empty, Group 2 (end=80) has data + // Should continue to Group 2 since Group 1 has no data + let expected_output = Some( + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![77, 76])]).unwrap(), + ); + + run_test( + 3005, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(2), + expected_output, + Some(7), // Must read until finding actual data + ) + .await; + + // Test case 2: Empty partitions between data groups + let input_ranged_data = vec![ + ( + PartitionRange { + start: Timestamp::new(70, unit.into()), + end: Timestamp::new(100, unit.into()), + num_rows: 4, + identifier: 0, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![96, 97, 98, 99])], + ) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(50, unit.into()), + end: Timestamp::new(90, unit.into()), + num_rows: 0, + identifier: 1, + }, + vec![ + // Empty range - should be skipped + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])]) + .unwrap(), + ], + ), + ( + PartitionRange { + start: Timestamp::new(30, unit.into()), + end: Timestamp::new(70, unit.into()), + num_rows: 0, + identifier: 2, + }, + vec![ + // Another empty range + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![])]) + .unwrap(), + ], + ), + ( + PartitionRange { + 
start: Timestamp::new(10, unit.into()), + end: Timestamp::new(50, unit.into()), + num_rows: 3, + identifier: 3, + }, + vec![ + DfRecordBatch::try_new( + schema.clone(), + vec![new_ts_array(unit, vec![48, 49, 50])], + ) + .unwrap(), + ], + ), + ]; + + // With limit=2 from group 1: [99, 98], threshold=98, next group end=50 + // Since 98 >= 50, we should stop early + let expected_output = Some( + DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![99, 98])]).unwrap(), + ); + + run_test( + 3006, + input_ranged_data, + schema.clone(), + SortOptions { + descending: true, + ..Default::default() + }, + Some(2), + expected_output, + Some(7), // Must read to detect early stop condition + ) + .await; + } } diff --git a/src/query/src/window_sort.rs b/src/query/src/window_sort.rs index e497656796..fad4e95db4 100644 --- a/src/query/src/window_sort.rs +++ b/src/query/src/window_sort.rs @@ -84,23 +84,31 @@ pub struct WindowedSortExec { properties: PlanProperties, } -fn check_partition_range_monotonicity( +/// Checks that partition ranges are sorted correctly for the given sort direction. +/// - Descending: sorted by (end DESC, start DESC) - shorter ranges first when ends are equal +/// - Ascending: sorted by (start ASC, end ASC) - shorter ranges first when starts are equal +pub fn check_partition_range_monotonicity( ranges: &[Vec], descending: bool, ) -> Result<()> { let is_valid = ranges.iter().all(|r| { if descending { - r.windows(2).all(|w| w[0].end >= w[1].end) + // Primary: end descending, Secondary: start descending (shorter range first) + r.windows(2) + .all(|w| w[0].end > w[1].end || (w[0].end == w[1].end && w[0].start >= w[1].start)) } else { - r.windows(2).all(|w| w[0].start <= w[1].start) + // Primary: start ascending, Secondary: end ascending (shorter range first) + r.windows(2).all(|w| { + w[0].start < w[1].start || (w[0].start == w[1].start && w[0].end <= w[1].end) + }) } }); if !is_valid { let msg = if descending { - "Input `PartitionRange`s's upper bound is not monotonic non-increase" + "Input `PartitionRange`s are not sorted by (end DESC, start DESC)" } else { - "Input `PartitionRange`s's lower bound is not monotonic non-decrease" + "Input `PartitionRange`s are not sorted by (start ASC, end ASC)" }; let plain_error = PlainError::new(msg.to_string(), StatusCode::Unexpected); Err(BoxedError::new(plain_error)).context(QueryExecutionSnafu {}) @@ -2829,8 +2837,9 @@ mod test { // generate input data for part_id in 0..rng.usize(0..part_cnt_bound) { let (start, end) = if descending { + // Use 1..=range_offset_bound to ensure strictly decreasing end values let end = bound_val - .map(|i| i - rng.i64(0..range_offset_bound)) + .map(|i| i - rng.i64(1..=range_offset_bound)) .unwrap_or_else(|| rng.i64(..)); bound_val = Some(end); let start = end - rng.i64(1..range_size_bound); @@ -2838,8 +2847,9 @@ mod test { let end = Timestamp::new(end, unit.into()); (start, end) } else { + // Use 1..=range_offset_bound to ensure strictly increasing start values let start = bound_val - .map(|i| i + rng.i64(0..range_offset_bound)) + .map(|i| i + rng.i64(1..=range_offset_bound)) .unwrap_or_else(|| rng.i64(..)); bound_val = Some(start); let end = start + rng.i64(1..range_size_bound);
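
A minimal standalone sketch of the ordering and grouping scheme the patch relies on, using a plain (i64, i64) pair in place of PartitionRange and i64 values in place of Timestamp; the helper names below are illustrative only, while the real implementation lives in group_ranges_by_primary_end and is validated by check_partition_range_monotonicity. The same grouping also drives the top-k early stop: once the buffer holds `limit` rows, the k-th value is compared against the next group's primary end (`end` for descending scans, `start` for ascending ones), and input reading stops when the next group cannot contain a better value.

// Simplified range: (start, end), end exclusive. Not the real PartitionRange type.
type Range = (i64, i64);

// Sort ranges the way ParallelizeScan does after this patch:
// descending -> (end DESC, start DESC), ascending -> (start ASC, end ASC).
// Ties on the primary end put the shorter range first.
fn sort_ranges(ranges: &mut [Range], descending: bool) {
    if descending {
        ranges.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| b.0.cmp(&a.0)));
    } else {
        ranges.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
    }
}

// Primary end of a range: `end` for descending scans, `start` for ascending ones.
fn primary_end(r: &Range, descending: bool) -> i64 {
    if descending { r.1 } else { r.0 }
}

// Group consecutive ranges sharing the same primary end, mirroring
// group_ranges_by_primary_end: each tuple is (primary_end, start_idx, end_idx_exclusive).
fn group_by_primary_end(ranges: &[Range], descending: bool) -> Vec<(i64, usize, usize)> {
    let mut groups = Vec::new();
    if ranges.is_empty() {
        return groups;
    }
    let mut group_start = 0;
    for idx in 1..ranges.len() {
        if primary_end(&ranges[idx], descending) != primary_end(&ranges[group_start], descending) {
            groups.push((primary_end(&ranges[group_start], descending), group_start, idx));
            group_start = idx;
        }
    }
    groups.push((primary_end(&ranges[group_start], descending), group_start, ranges.len()));
    groups
}

fn main() {
    // Ranges from test case 5 above: [15,20), [10,15), [5,10), [0,10), descending sort.
    let mut ranges = vec![(0, 10), (5, 10), (10, 15), (15, 20)];
    sort_ranges(&mut ranges, true);
    assert_eq!(ranges, vec![(15, 20), (10, 15), (5, 10), (0, 10)]);

    // The last two ranges share end = 10, so they form one group: their data may
    // overlap and must be buffered together before the group is sorted and emitted.
    let groups = group_by_primary_end(&ranges, true);
    assert_eq!(groups, vec![(20, 0, 1), (15, 1, 2), (10, 2, 4)]);
}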