fix: part sort behavior (#7374)

* fix: part sort behavior

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* tune tests

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* debug assertion and remove produced count

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Ruihang Xia
2025-12-10 15:44:44 +08:00
committed by GitHub
parent 4d1a587079
commit 6817a376b5
3 changed files with 336 additions and 28 deletions
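
The core of the fix is that `PartSortStream` now tracks the remaining `limit` directly instead of keeping a separate `produced` counter: after every emitted batch the limit is reduced with `saturating_sub`, and once it reaches `Some(0)` the stream stops polling its input. A minimal, self-contained sketch of that pattern follows (illustrative names only, not the actual `PartSortStream` API):

// Hypothetical helper illustrating the limit-tracking pattern in the diff below:
// the remaining limit is reduced after each emitted batch and polling stops
// once it reaches Some(0). A limit of None means "unlimited".
fn emit(limit: &mut Option<usize>, batch_rows: usize) -> Option<usize> {
    // Early termination: the limit is exhausted, don't emit (or poll) anything else.
    if matches!(*limit, Some(0)) {
        return None;
    }
    // Emit at most the remaining number of rows.
    let emitted = match *limit {
        Some(l) => batch_rows.min(l),
        None => batch_rows,
    };
    // Track how many rows may still be produced, never going below zero.
    *limit = limit.map(|l| l.saturating_sub(emitted));
    Some(emitted)
}

fn main() {
    let mut limit = Some(3);
    assert_eq!(emit(&mut limit, 2), Some(2)); // 2 rows emitted, 1 remaining
    assert_eq!(emit(&mut limit, 5), Some(1)); // capped at the remaining 1 row
    assert_eq!(emit(&mut limit, 5), None);    // limit is Some(0): early termination
}
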


@@ -284,7 +284,6 @@ struct PartSortStream {
buffer: PartSortBuffer,
expression: PhysicalSortExpr,
limit: Option<usize>,
produced: usize,
input: DfSendableRecordBatchStream,
input_complete: bool,
schema: SchemaRef,
@@ -340,7 +339,6 @@ impl PartSortStream {
buffer,
expression: sort.expression.clone(),
limit,
produced: 0,
input,
input_complete: false,
schema: sort.input.schema(),
@@ -565,7 +563,6 @@ impl PartSortStream {
)
})?;
self.produced += sorted.num_rows();
drop(full_input);
// here release both the buffer's and full_input's memory from the reservation
self.reservation.shrink(2 * total_mem);
@@ -666,6 +663,16 @@ impl PartSortStream {
let sorted_batch = self.sort_buffer();
// step to next proper PartitionRange
self.cur_part_idx += 1;
// If we've processed all partitions, discard remaining data
if self.cur_part_idx >= self.partition_ranges.len() {
// assert there is no data beyond the last partition range (i.e. `remaining_range` is empty).
// even if there were, it would be acceptable, because `remaining_range` is discarded anyway.
debug_assert!(remaining_range.num_rows() == 0);
return sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) });
}
let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
if self.try_find_next_range(&next_sort_column)?.is_some() {
// remaining batch still contains data that exceeds the current partition range
@@ -687,6 +694,12 @@ impl PartSortStream {
cx: &mut Context<'_>,
) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
loop {
// Early termination: if we've already produced enough rows,
// don't poll more input - just return
if matches!(self.limit, Some(0)) {
return Poll::Ready(None);
}
// no more input, sort the buffer and return
if self.input_complete {
if self.buffer.is_empty() {
@@ -701,7 +714,24 @@ impl PartSortStream {
if let Some(evaluating_batch) = self.evaluating_batch.take()
&& evaluating_batch.num_rows() != 0
{
// Check if we've already processed all partitions
if self.cur_part_idx >= self.partition_ranges.len() {
// All partitions processed, discard remaining data
if self.buffer.is_empty() {
return Poll::Ready(None);
} else {
let sorted_batch = self.sort_buffer()?;
self.limit = self
.limit
.map(|l| l.saturating_sub(sorted_batch.num_rows()));
return Poll::Ready(Some(Ok(sorted_batch)));
}
}
if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
self.limit = self
.limit
.map(|l| l.saturating_sub(sorted_batch.num_rows()));
return Poll::Ready(Some(Ok(sorted_batch)));
} else {
continue;
@@ -713,6 +743,9 @@ impl PartSortStream {
match res {
Poll::Ready(Some(Ok(batch))) => {
if let Some(sorted_batch) = self.split_batch(batch)? {
self.limit = self
.limit
.map(|l| l.saturating_sub(sorted_batch.num_rows()));
return Poll::Ready(Some(Ok(sorted_batch)));
} else {
continue;
@@ -896,22 +929,30 @@ mod test {
output_data.push(cur_data);
}
let expected_output = output_data
let mut limit_remains = limit;
let mut expected_output = output_data
.into_iter()
.map(|a| {
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
})
.map(|rb| {
// trim expected output with limit
if let Some(limit) = limit
&& rb.num_rows() > limit
{
rb.slice(0, limit)
if let Some(limit) = limit_remains.as_mut() {
let rb = rb.slice(0, (*limit).min(rb.num_rows()));
*limit = limit.saturating_sub(rb.num_rows());
rb
} else {
rb
}
})
.collect_vec();
while let Some(rb) = expected_output.last() {
if rb.num_rows() == 0 {
expected_output.pop();
} else {
break;
}
}
test_cases.push((
case_id,
@@ -932,13 +973,14 @@ mod test {
opt,
limit,
expected_output,
None,
)
.await;
}
}
#[tokio::test]
async fn simple_case() {
async fn simple_cases() {
let testcases = vec![
(
TimeUnit::Millisecond,
@@ -1027,7 +1069,7 @@ mod test {
],
true,
Some(2),
vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
vec![vec![19, 17]],
),
];
@@ -1080,6 +1122,7 @@ mod test {
opt,
limit,
expected_output,
None,
)
.await;
}
@@ -1093,6 +1136,7 @@ mod test {
opt: SortOptions,
limit: Option<usize>,
expected_output: Vec<DfRecordBatch>,
expected_polled_rows: Option<usize>,
) {
for rb in &expected_output {
if let Some(limit) = limit {
@@ -1104,16 +1148,15 @@ mod test {
);
}
}
let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
let batches = batches
.into_iter()
.flat_map(|mut cols| {
cols.push(DfRecordBatch::new_empty(schema.clone()));
cols
})
.collect_vec();
let mock_input = MockInputExec::new(batches, schema.clone());
let mut data_partition = Vec::with_capacity(input_ranged_data.len());
let mut ranges = Vec::with_capacity(input_ranged_data.len());
for (part_range, batches) in input_ranged_data {
data_partition.push(batches);
ranges.push(part_range);
}
let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));
let exec = PartSortExec::new(
PhysicalSortExpr {
@@ -1122,7 +1165,7 @@ mod test {
},
limit,
vec![ranges.clone()],
Arc::new(mock_input),
mock_input.clone(),
);
let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
@@ -1131,12 +1174,17 @@ mod test {
// a makeshift solution for comparing large data
if real_output != expected_output {
let mut first_diff = 0;
let mut is_diff_found = false;
for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
if lhs != rhs {
if lhs.slice(0, rhs.num_rows()) != *rhs {
first_diff = idx;
is_diff_found = true;
break;
}
}
if !is_diff_found {
return;
}
println!("first diff batch at {}", first_diff);
println!(
"ranges: {:?}",
@@ -1175,8 +1223,14 @@ mod test {
let buf = String::from_utf8_lossy(&buf);
full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
}
if let Some(expected_polled_rows) = expected_polled_rows {
let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
assert_eq!(input_pulled_rows, expected_polled_rows);
}
panic!(
"case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
"case_{} failed (limit {limit:?}), opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
case_id,
opt,
real_output.len(),
@@ -1187,4 +1241,249 @@ mod test {
);
}
}
/// Test that verifies the limit is correctly applied when multiple batches
/// are received for the same partition.
#[tokio::test]
async fn test_limit_with_multiple_batches_per_partition() {
let unit = TimeUnit::Millisecond;
let schema = Arc::new(Schema::new(vec![Field::new(
"ts",
DataType::Timestamp(unit, None),
false,
)]));
// Test case: Multiple batches in a single partition with limit=3
// Input: 3 batches with [1,2,3], [4,5,6], [7,8,9] all in partition (0,10)
// Expected: Only top 3 values [9,8,7] for descending sort
let input_ranged_data = vec![(
PartitionRange {
start: Timestamp::new(0, unit.into()),
end: Timestamp::new(10, unit.into()),
num_rows: 9,
identifier: 0,
},
vec![
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
.unwrap(),
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
.unwrap(),
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
.unwrap(),
],
)];
let expected_output = vec![
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
.unwrap(),
];
run_test(
1000,
input_ranged_data,
schema.clone(),
SortOptions {
descending: true,
..Default::default()
},
Some(3),
expected_output,
None,
)
.await;
// Test case: Multiple batches across multiple partitions with limit=2
// Partition 0: batches [10,11,12], [13,14,15] -> top 2 descending = [15,14]
// The limit is exhausted after partition 0, so partition 1 ([1,2,3], [4,5]) produces no output
let input_ranged_data = vec![
(
PartitionRange {
start: Timestamp::new(10, unit.into()),
end: Timestamp::new(20, unit.into()),
num_rows: 6,
identifier: 0,
},
vec![
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![10, 11, 12])],
)
.unwrap(),
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![13, 14, 15])],
)
.unwrap(),
],
),
(
PartitionRange {
start: Timestamp::new(0, unit.into()),
end: Timestamp::new(10, unit.into()),
num_rows: 5,
identifier: 1,
},
vec![
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
.unwrap(),
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
.unwrap(),
],
),
];
let expected_output = vec![
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
];
run_test(
1001,
input_ranged_data,
schema.clone(),
SortOptions {
descending: true,
..Default::default()
},
Some(2),
expected_output,
None,
)
.await;
// Test case: Ascending sort with limit
// Partition: batches [7,8,9], [4,5,6], [1,2,3] -> top 2 ascending = [1,2]
let input_ranged_data = vec![(
PartitionRange {
start: Timestamp::new(0, unit.into()),
end: Timestamp::new(10, unit.into()),
num_rows: 9,
identifier: 0,
},
vec![
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
.unwrap(),
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
.unwrap(),
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
.unwrap(),
],
)];
let expected_output = vec![
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
];
run_test(
1002,
input_ranged_data,
schema.clone(),
SortOptions {
descending: false,
..Default::default()
},
Some(2),
expected_output,
None,
)
.await;
}
/// Test that verifies early termination behavior.
/// Once the limit is exhausted, the stream should stop pulling from the input.
#[tokio::test]
async fn test_early_termination() {
let unit = TimeUnit::Millisecond;
let schema = Arc::new(Schema::new(vec![Field::new(
"ts",
DataType::Timestamp(unit, None),
false,
)]));
// Create 3 partitions, each with more data than the limit.
// With limit=2, early termination should kick in once the top 2 rows are produced,
// without pulling input for the remaining partitions
let input_ranged_data = vec![
(
PartitionRange {
start: Timestamp::new(0, unit.into()),
end: Timestamp::new(10, unit.into()),
num_rows: 10,
identifier: 0,
},
vec![
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
)
.unwrap(),
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
)
.unwrap(),
],
),
(
PartitionRange {
start: Timestamp::new(10, unit.into()),
end: Timestamp::new(20, unit.into()),
num_rows: 10,
identifier: 1,
},
vec![
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
)
.unwrap(),
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
)
.unwrap(),
],
),
(
PartitionRange {
start: Timestamp::new(20, unit.into()),
end: Timestamp::new(30, unit.into()),
num_rows: 10,
identifier: 2,
},
vec![
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
)
.unwrap(),
DfRecordBatch::try_new(
schema.clone(),
vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
)
.unwrap(),
],
),
];
// PartSort won't reorder `PartitionRange`s (it assumes they are already ordered), so it will not read the other partitions.
// This case just verifies that early termination works as expected.
let expected_output = vec![
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8])]).unwrap(),
];
run_test(
1003,
input_ranged_data,
schema.clone(),
SortOptions {
descending: true,
..Default::default()
},
Some(2),
expected_output,
Some(10),
)
.await;
}
}
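
The reworked expected-output construction in the tests above consumes a single running `limit_remains` across batches and then drops trailing batches that the limit emptied. A simplified sketch of the same trimming over plain vectors (hypothetical helper, not part of the test harness):

// Trim a sequence of "batches" with one shared limit, mirroring `limit_remains`
// above: each batch keeps at most the remaining rows, the remainder is reduced,
// and batches left empty at the tail are dropped.
fn trim_expected(batches: Vec<Vec<i64>>, limit: Option<usize>) -> Vec<Vec<i64>> {
    let mut remaining = limit;
    let mut out: Vec<Vec<i64>> = batches
        .into_iter()
        .map(|batch| {
            if let Some(l) = remaining.as_mut() {
                let kept: Vec<i64> = batch.into_iter().take(*l).collect();
                *l = l.saturating_sub(kept.len());
                kept
            } else {
                batch
            }
        })
        .collect();
    // Drop trailing batches emptied by the limit, like the `while let` loop above.
    while matches!(out.last(), Some(b) if b.is_empty()) {
        out.pop();
    }
    out
}

fn main() {
    // With limit = 2 only the first two values survive, matching the updated
    // `vec![vec![19, 17]]` expectation in `simple_cases`.
    let batches = vec![vec![19, 17, 15], vec![12, 11], vec![9, 8]];
    assert_eq!(trim_expected(batches, Some(2)), vec![vec![19, 17]]);
}
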


@@ -25,6 +25,7 @@ use arrow_schema::{SchemaRef, TimeUnit};
use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
use datafusion::execution::{RecordBatchStream, TaskContext};
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use futures::Stream;
@@ -46,13 +47,14 @@ pub fn new_ts_array(unit: TimeUnit, arr: Vec<i64>) -> ArrayRef {
#[derive(Debug)]
pub struct MockInputExec {
input: Vec<DfRecordBatch>,
input: Vec<Vec<DfRecordBatch>>,
schema: SchemaRef,
properties: PlanProperties,
metrics: ExecutionPlanMetricsSet,
}
impl MockInputExec {
pub fn new(input: Vec<DfRecordBatch>, schema: SchemaRef) -> Self {
pub fn new(input: Vec<Vec<DfRecordBatch>>, schema: SchemaRef) -> Self {
Self {
properties: PlanProperties::new(
EquivalenceProperties::new(schema.clone()),
@@ -62,6 +64,7 @@ impl MockInputExec {
),
input,
schema,
metrics: ExecutionPlanMetricsSet::new(),
}
}
}
@@ -98,22 +101,28 @@ impl ExecutionPlan for MockInputExec {
fn execute(
&self,
_partition: usize,
partition: usize,
_context: Arc<TaskContext>,
) -> datafusion_common::Result<DfSendableRecordBatchStream> {
let stream = MockStream {
stream: self.input.clone(),
stream: self.input.clone().into_iter().flatten().collect(),
schema: self.schema.clone(),
idx: 0,
metrics: BaselineMetrics::new(&self.metrics, partition),
};
Ok(Box::pin(stream))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
}
struct MockStream {
stream: Vec<DfRecordBatch>,
schema: SchemaRef,
idx: usize,
metrics: BaselineMetrics,
}
impl Stream for MockStream {
@@ -125,7 +134,7 @@ impl Stream for MockStream {
if self.idx < self.stream.len() {
let ret = self.stream[self.idx].clone();
self.idx += 1;
Poll::Ready(Some(Ok(ret)))
self.metrics.record_poll(Poll::Ready(Some(Ok(ret))))
} else {
Poll::Ready(None)
}
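
`MockStream` now routes every batch it returns through `BaselineMetrics::record_poll`, so a test can read back how many rows were actually pulled from the mock input via `metrics()` and `output_rows()` (this is what `expected_polled_rows` asserts). A small standalone sketch of that counting pattern, assuming the DataFusion metrics API used above:

use std::sync::Arc;
use std::task::Poll;

use datafusion::arrow::array::Int64Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};

fn main() -> datafusion::error::Result<()> {
    // One metrics set per plan node; a BaselineMetrics handle per partition.
    let metrics_set = ExecutionPlanMetricsSet::new();
    let baseline = BaselineMetrics::new(&metrics_set, /* partition */ 0);

    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
    let batch =
        RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1, 2, 3]))])?;

    // Wrapping a poll result in record_poll counts the rows that pass through.
    let _ = baseline.record_poll(Poll::Ready(Some(Ok(batch))));

    // A test can then assert on how many rows were pulled from this "input".
    assert_eq!(metrics_set.clone_inner().output_rows(), Some(3));
    Ok(())
}
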


@@ -2528,7 +2528,7 @@ mod test {
async fn run_test(&self) -> Vec<DfRecordBatch> {
let (ranges, batches): (Vec<_>, Vec<_>) = self.input.clone().into_iter().unzip();
let mock_input = MockInputExec::new(batches, self.schema.clone());
let mock_input = MockInputExec::new(vec![batches], self.schema.clone());
let exec = WindowedSortExec::try_new(
self.expression.clone(),