From 6817a376b5e7848e0dd78a9ea8c30e2e3d015c85 Mon Sep 17 00:00:00 2001
From: Ruihang Xia
Date: Wed, 10 Dec 2025 15:44:44 +0800
Subject: [PATCH] fix: part sort behavior (#7374)

* fix: part sort behavior

Signed-off-by: Ruihang Xia

* tune tests

Signed-off-by: Ruihang Xia

* debug assertion and remove produced count

Signed-off-by: Ruihang Xia

---------

Signed-off-by: Ruihang Xia
---
 src/query/src/part_sort.rs   | 343 ++++++++++++++++++++++++++++++++---
 src/query/src/test_util.rs   |  19 +-
 src/query/src/window_sort.rs |   2 +-
 3 files changed, 336 insertions(+), 28 deletions(-)

diff --git a/src/query/src/part_sort.rs b/src/query/src/part_sort.rs
index ebf4fddc1e..22f968f6f9 100644
--- a/src/query/src/part_sort.rs
+++ b/src/query/src/part_sort.rs
@@ -284,7 +284,6 @@ struct PartSortStream {
     buffer: PartSortBuffer,
     expression: PhysicalSortExpr,
     limit: Option<usize>,
-    produced: usize,
     input: DfSendableRecordBatchStream,
     input_complete: bool,
     schema: SchemaRef,
@@ -340,7 +339,6 @@ impl PartSortStream {
             buffer,
             expression: sort.expression.clone(),
             limit,
-            produced: 0,
             input,
             input_complete: false,
             schema: sort.input.schema(),
@@ -565,7 +563,6 @@ impl PartSortStream {
             )
         })?;
 
-        self.produced += sorted.num_rows();
         drop(full_input);
         // here remove both buffer and full_input memory
         self.reservation.shrink(2 * total_mem);
@@ -666,6 +663,16 @@ impl PartSortStream {
             let sorted_batch = self.sort_buffer();
             // step to next proper PartitionRange
             self.cur_part_idx += 1;
+
+            // If we've processed all partitions, discard remaining data
+            if self.cur_part_idx >= self.partition_ranges.len() {
+                // assert there is no data beyond the last partition range (remaining is empty).
+                // it would be acceptable even if some data remains, because `remaining_range`
+                // will be discarded anyway.
+                debug_assert!(remaining_range.num_rows() == 0);
+
+                return sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) });
+            }
+
             let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
             if self.try_find_next_range(&next_sort_column)?.is_some() {
                 // remaining batch still contains data that exceeds the current partition range
@@ -687,6 +694,12 @@ impl PartSortStream {
         cx: &mut Context<'_>,
     ) -> Poll<Option<DataFusionResult<DfRecordBatch>>> {
         loop {
+            // Early termination: if we've already produced enough rows,
+            // don't poll more input, just return
+            if matches!(self.limit, Some(0)) {
+                return Poll::Ready(None);
+            }
+
             // no more input, sort the buffer and return
             if self.input_complete {
                 if self.buffer.is_empty() {
@@ -701,7 +714,24 @@ impl PartSortStream {
             if let Some(evaluating_batch) = self.evaluating_batch.take()
                 && evaluating_batch.num_rows() != 0
             {
+                // Check if we've already processed all partitions
+                if self.cur_part_idx >= self.partition_ranges.len() {
+                    // All partitions processed, discard remaining data
+                    if self.buffer.is_empty() {
+                        return Poll::Ready(None);
+                    } else {
+                        let sorted_batch = self.sort_buffer()?;
+                        self.limit = self
+                            .limit
+                            .map(|l| l.saturating_sub(sorted_batch.num_rows()));
+                        return Poll::Ready(Some(Ok(sorted_batch)));
+                    }
+                }
+
                 if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
+                    self.limit = self
+                        .limit
+                        .map(|l| l.saturating_sub(sorted_batch.num_rows()));
                     return Poll::Ready(Some(Ok(sorted_batch)));
                 } else {
                     continue;
@@ -713,6 +743,9 @@ impl PartSortStream {
             match res {
                 Poll::Ready(Some(Ok(batch))) => {
                     if let Some(sorted_batch) = self.split_batch(batch)? {
+                        self.limit = self
+                            .limit
+                            .map(|l| l.saturating_sub(sorted_batch.num_rows()));
                         return Poll::Ready(Some(Ok(sorted_batch)));
                     } else {
                         continue;
                     }
@@ -896,22 +929,30 @@ mod test {
             output_data.push(cur_data);
         }
 
-        let expected_output = output_data
+        let mut limit_remains = limit;
+        let mut expected_output = output_data
             .into_iter()
            .map(|a| {
                 DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
             })
             .map(|rb| {
                 // trim expected output with limit
-                if let Some(limit) = limit
-                    && rb.num_rows() > limit
-                {
-                    rb.slice(0, limit)
+                if let Some(limit) = limit_remains.as_mut() {
+                    let rb = rb.slice(0, (*limit).min(rb.num_rows()));
+                    *limit = limit.saturating_sub(rb.num_rows());
+                    rb
                 } else {
                     rb
                 }
             })
             .collect_vec();
+        while let Some(rb) = expected_output.last() {
+            if rb.num_rows() == 0 {
+                expected_output.pop();
+            } else {
+                break;
+            }
+        }
 
         test_cases.push((
             case_id,
@@ -932,13 +973,14 @@ mod test {
             opt,
             limit,
             expected_output,
+            None,
         )
         .await;
     }
 }

     #[tokio::test]
-    async fn simple_case() {
+    async fn simple_cases() {
         let testcases = vec![
             (
                 TimeUnit::Millisecond,
@@ -1027,7 +1069,7 @@ mod test {
                 ],
                 true,
                 Some(2),
-                vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
+                vec![vec![19, 17]],
             ),
         ];
@@ -1080,6 +1122,7 @@ mod test {
                 opt,
                 limit,
                 expected_output,
+                None,
            )
            .await;
        }
@@ -1093,6 +1136,7 @@ mod test {
        opt: SortOptions,
        limit: Option<usize>,
        expected_output: Vec<DfRecordBatch>,
+        expected_polled_rows: Option<usize>,
    ) {
        for rb in &expected_output {
            if let Some(limit) = limit {
@@ -1104,16 +1148,15 @@ mod test {
                );
            }
        }
-        let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
-        let batches = batches
-            .into_iter()
-            .flat_map(|mut cols| {
-                cols.push(DfRecordBatch::new_empty(schema.clone()));
-                cols
-            })
-            .collect_vec();
 
-        let mock_input = MockInputExec::new(batches, schema.clone());
+        let mut data_partition = Vec::with_capacity(input_ranged_data.len());
+        let mut ranges = Vec::with_capacity(input_ranged_data.len());
+        for (part_range, batches) in input_ranged_data {
+            data_partition.push(batches);
+            ranges.push(part_range);
+        }
+
+        let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));
 
        let exec = PartSortExec::new(
            PhysicalSortExpr {
@@ -1122,7 +1165,7 @@ mod test {
            },
            limit,
            vec![ranges.clone()],
-            Arc::new(mock_input),
+            mock_input.clone(),
        );
 
        let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
@@ -1131,12 +1174,17 @@ mod test {
        // a makeshift solution for comparing large data
        if real_output != expected_output {
            let mut first_diff = 0;
+            let mut is_diff_found = false;
            for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
-                if lhs != rhs {
+                if lhs.slice(0, rhs.num_rows()) != *rhs {
                    first_diff = idx;
+                    is_diff_found = true;
                    break;
                }
            }
+            if !is_diff_found {
+                return;
+            }
            println!("first diff batch at {}", first_diff);
            println!(
                "ranges: {:?}",
@@ -1175,8 +1223,14 @@ mod test {
            let buf = String::from_utf8_lossy(&buf);
            full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
        }
+
+        if let Some(expected_polled_rows) = expected_polled_rows {
+            let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
+            assert_eq!(input_pulled_rows, expected_polled_rows);
+        }
+
            panic!(
-                "case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
+                "case_{} failed (limit {limit:?}), opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
                case_id,
                opt,
                real_output.len(),
@@ -1187,4 +1241,249 @@ mod test {
            );
        }
    }
+
+    /// Test that verifies the limit is correctly applied when a partition's
+    /// rows arrive split across multiple batches.
+    #[tokio::test]
+    async fn test_limit_with_multiple_batches_per_partition() {
+        let unit = TimeUnit::Millisecond;
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "ts",
+            DataType::Timestamp(unit, None),
+            false,
+        )]));
+
+        // Test case: Multiple batches in a single partition with limit=3
+        // Input: 3 batches with [1,2,3], [4,5,6], [7,8,9] all in partition (0,10)
+        // Expected: Only top 3 values [9,8,7] for descending sort
+        let input_ranged_data = vec![(
+            PartitionRange {
+                start: Timestamp::new(0, unit.into()),
+                end: Timestamp::new(10, unit.into()),
+                num_rows: 9,
+                identifier: 0,
+            },
+            vec![
+                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
+                    .unwrap(),
+                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
+                    .unwrap(),
+                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
+                    .unwrap(),
+            ],
+        )];
+
+        let expected_output = vec![
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
+                .unwrap(),
+        ];
+
+        run_test(
+            1000,
+            input_ranged_data,
+            schema.clone(),
+            SortOptions {
+                descending: true,
+                ..Default::default()
+            },
+            Some(3),
+            expected_output,
+            None,
+        )
+        .await;
+
+        // Test case: Multiple batches across multiple partitions with limit=2
+        // Partition 0: batches [10,11,12], [13,14,15] -> top 2 descending = [15,14]
+        // Partition 1: batches [1,2,3], [4,5] -> never emitted: the limit is
+        // already exhausted after partition 0, so the output is just [15,14]
+        let input_ranged_data = vec![
+            (
+                PartitionRange {
+                    start: Timestamp::new(10, unit.into()),
+                    end: Timestamp::new(20, unit.into()),
+                    num_rows: 6,
+                    identifier: 0,
+                },
+                vec![
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![10, 11, 12])],
+                    )
+                    .unwrap(),
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![13, 14, 15])],
+                    )
+                    .unwrap(),
+                ],
+            ),
+            (
+                PartitionRange {
+                    start: Timestamp::new(0, unit.into()),
+                    end: Timestamp::new(10, unit.into()),
+                    num_rows: 5,
+                    identifier: 1,
+                },
+                vec![
+                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
+                        .unwrap(),
+                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
+                        .unwrap(),
+                ],
+            ),
+        ];
+
+        let expected_output = vec![
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
+        ];
+
+        run_test(
+            1001,
+            input_ranged_data,
+            schema.clone(),
+            SortOptions {
+                descending: true,
+                ..Default::default()
+            },
+            Some(2),
+            expected_output,
+            None,
+        )
+        .await;
+
+        // Test case: Ascending sort with limit
+        // Partition: batches [7,8,9], [4,5,6], [1,2,3] -> top 2 ascending = [1,2]
+        let input_ranged_data = vec![(
+            PartitionRange {
+                start: Timestamp::new(0, unit.into()),
+                end: Timestamp::new(10, unit.into()),
+                num_rows: 9,
+                identifier: 0,
+            },
+            vec![
+                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
+                    .unwrap(),
+                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
+                    .unwrap(),
+                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
+                    .unwrap(),
+            ],
+        )];
+
+        let expected_output = vec![
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
+        ];
+
+        run_test(
+            1002,
+            input_ranged_data,
+            schema.clone(),
+            SortOptions {
+                descending: false,
+                ..Default::default()
+            },
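+            // limit of 2: with ascending order, only the two smallest values survive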
+            Some(2),
+            expected_output,
+            None,
+        )
+        .await;
+    }
+
+    /// Test that verifies early termination behavior.
+    /// Once we've produced `limit` rows, we should stop pulling from the
+    /// input stream.
+    #[tokio::test]
+    async fn test_early_termination() {
+        let unit = TimeUnit::Millisecond;
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "ts",
+            DataType::Timestamp(unit, None),
+            false,
+        )]));
+
+        // Create 3 partitions, each with more data than the limit. The limit
+        // applies to the operator as a whole, so once it is exhausted early
+        // termination should kick in and stop polling the input.
+        let input_ranged_data = vec![
+            (
+                PartitionRange {
+                    start: Timestamp::new(0, unit.into()),
+                    end: Timestamp::new(10, unit.into()),
+                    num_rows: 10,
+                    identifier: 0,
+                },
+                vec![
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
+                    )
+                    .unwrap(),
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
+                    )
+                    .unwrap(),
+                ],
+            ),
+            (
+                PartitionRange {
+                    start: Timestamp::new(10, unit.into()),
+                    end: Timestamp::new(20, unit.into()),
+                    num_rows: 10,
+                    identifier: 1,
+                },
+                vec![
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
+                    )
+                    .unwrap(),
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
+                    )
+                    .unwrap(),
+                ],
+            ),
+            (
+                PartitionRange {
+                    start: Timestamp::new(20, unit.into()),
+                    end: Timestamp::new(30, unit.into()),
+                    num_rows: 10,
+                    identifier: 2,
+                },
+                vec![
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
+                    )
+                    .unwrap(),
+                    DfRecordBatch::try_new(
+                        schema.clone(),
+                        vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
+                    )
+                    .unwrap(),
+                ],
+            ),
+        ];
+
+        // PartSort won't reorder `PartitionRange`s (it assumes they are already
+        // ordered), so it will not read the other partitions.
+        // This case is just to verify that early termination works as expected.
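+        // With the ranges ordered ascending but the sort descending, the limit is
+        // exhausted by partition 0 alone: its two batches (10 rows, matching the
+        // `Some(10)` polled-rows expectation below) yield the top-2 values [9, 8].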
+        let expected_output = vec![
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8])]).unwrap(),
+        ];
+
+        run_test(
+            1003,
+            input_ranged_data,
+            schema.clone(),
+            SortOptions {
+                descending: true,
+                ..Default::default()
+            },
+            Some(2),
+            expected_output,
+            Some(10),
+        )
+        .await;
+    }
 }
diff --git a/src/query/src/test_util.rs b/src/query/src/test_util.rs
index f64718b84a..55891b0063 100644
--- a/src/query/src/test_util.rs
+++ b/src/query/src/test_util.rs
@@ -25,6 +25,7 @@ use arrow_schema::{SchemaRef, TimeUnit};
 use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
 use datafusion::execution::{RecordBatchStream, TaskContext};
 use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
 use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
 use futures::Stream;
@@ -46,13 +47,14 @@ pub fn new_ts_array(unit: TimeUnit, arr: Vec<i64>) -> ArrayRef {
 
 #[derive(Debug)]
 pub struct MockInputExec {
-    input: Vec<DfRecordBatch>,
+    input: Vec<Vec<DfRecordBatch>>,
     schema: SchemaRef,
     properties: PlanProperties,
+    metrics: ExecutionPlanMetricsSet,
 }
 
 impl MockInputExec {
-    pub fn new(input: Vec<DfRecordBatch>, schema: SchemaRef) -> Self {
+    pub fn new(input: Vec<Vec<DfRecordBatch>>, schema: SchemaRef) -> Self {
         Self {
             properties: PlanProperties::new(
                 EquivalenceProperties::new(schema.clone()),
@@ -62,6 +64,7 @@ impl MockInputExec {
             ),
             input,
             schema,
+            metrics: ExecutionPlanMetricsSet::new(),
         }
     }
 }
@@ -98,22 +101,28 @@ impl ExecutionPlan for MockInputExec {
     fn execute(
         &self,
-        _partition: usize,
+        partition: usize,
         _context: Arc<TaskContext>,
     ) -> datafusion_common::Result<DfSendableRecordBatchStream> {
         let stream = MockStream {
-            stream: self.input.clone(),
+            stream: self.input.clone().into_iter().flatten().collect(),
             schema: self.schema.clone(),
             idx: 0,
+            metrics: BaselineMetrics::new(&self.metrics, partition),
         };
         Ok(Box::pin(stream))
     }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
 }
 
 struct MockStream {
     stream: Vec<DfRecordBatch>,
     schema: SchemaRef,
     idx: usize,
+    metrics: BaselineMetrics,
 }
 
 impl Stream for MockStream {
     type Item = datafusion_common::Result<DfRecordBatch>;
@@ -125,7 +134,7 @@ impl Stream for MockStream {
         if self.idx < self.stream.len() {
             let ret = self.stream[self.idx].clone();
             self.idx += 1;
-            Poll::Ready(Some(Ok(ret)))
+            self.metrics.record_poll(Poll::Ready(Some(Ok(ret))))
         } else {
             Poll::Ready(None)
         }
diff --git a/src/query/src/window_sort.rs b/src/query/src/window_sort.rs
index eb0aa2d071..47ee8be75a 100644
--- a/src/query/src/window_sort.rs
+++ b/src/query/src/window_sort.rs
@@ -2528,7 +2528,7 @@ mod test {
     async fn run_test(&self) -> Vec<DfRecordBatch> {
         let (ranges, batches): (Vec<_>, Vec<_>) = self.input.clone().into_iter().unzip();
 
-        let mock_input = MockInputExec::new(batches, self.schema.clone());
+        let mock_input = MockInputExec::new(vec![batches], self.schema.clone());
 
         let exec = WindowedSortExec::try_new(
             self.expression.clone(),
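Taken together, the patch threads `limit` through every emission site of `PartSortStream` and short-circuits the poll loop once the budget is spent. Below is a minimal, self-contained sketch of that accounting pattern, illustrative only: plain `usize` counts stand in for record batches, and `batch_rows`, `take`, `emitted`, and `polled` are made-up names, not identifiers from the patch.

// Sketch of the limit accounting this patch applies: every emitted batch
// decrements the remaining limit with `saturating_sub`, and the loop
// short-circuits via `matches!(limit, Some(0))` so no further input is pulled.
fn main() {
    // Hypothetical batch sizes arriving from an input stream.
    let batch_rows = [3usize, 3, 3, 5, 5];
    // `None` means unlimited, mirroring `limit: Option<usize>`.
    let mut limit: Option<usize> = Some(4);
    let mut emitted = 0usize;
    let mut polled = 0usize;

    for rows in batch_rows {
        // Early termination: stop polling input once the limit is exhausted.
        if matches!(limit, Some(0)) {
            break;
        }
        polled += rows;
        // Emit at most the remaining limit from this (already sorted) batch.
        let take = limit.map_or(rows, |l| l.min(rows));
        emitted += take;
        limit = limit.map(|l| l.saturating_sub(take));
    }

    // Two batches satisfy the limit; the last three are never polled.
    assert_eq!((emitted, polled), (4, 6));
    assert_eq!(limit, Some(0));
}

Because `saturating_sub` never underflows even when a batch overshoots the remaining budget, the stream can treat `Some(0)` as a terminal state, which is exactly what the new `test_early_termination` case exercises through the mock input's polled-row metrics.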