mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-22 22:20:02 +00:00
fix: part sort behavior (#7374)
* fix: part sort behavior Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * tune tests Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * debug assertion and remove produced count Signed-off-by: Ruihang Xia <waynestxia@gmail.com> --------- Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
@@ -284,7 +284,6 @@ struct PartSortStream {
|
||||
buffer: PartSortBuffer,
|
||||
expression: PhysicalSortExpr,
|
||||
limit: Option<usize>,
|
||||
produced: usize,
|
||||
input: DfSendableRecordBatchStream,
|
||||
input_complete: bool,
|
||||
schema: SchemaRef,
|
||||
@@ -340,7 +339,6 @@ impl PartSortStream {
|
||||
buffer,
|
||||
expression: sort.expression.clone(),
|
||||
limit,
|
||||
produced: 0,
|
||||
input,
|
||||
input_complete: false,
|
||||
schema: sort.input.schema(),
|
||||
@@ -565,7 +563,6 @@ impl PartSortStream {
|
||||
)
|
||||
})?;
|
||||
|
||||
self.produced += sorted.num_rows();
|
||||
drop(full_input);
|
||||
// here remove both buffer and full_input memory
|
||||
self.reservation.shrink(2 * total_mem);
|
||||
@@ -666,6 +663,16 @@ impl PartSortStream {
|
||||
let sorted_batch = self.sort_buffer();
|
||||
// step to next proper PartitionRange
|
||||
self.cur_part_idx += 1;
|
||||
|
||||
// If we've processed all partitions, discard remaining data
|
||||
if self.cur_part_idx >= self.partition_ranges.len() {
|
||||
// assert there is no data beyond the last partition range (remaining is empty).
|
||||
// it would be acceptable even if it happens, because `remaining_range` will be discarded anyway.
|
||||
debug_assert!(remaining_range.num_rows() == 0);
|
||||
|
||||
return sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) });
|
||||
}
|
||||
|
||||
let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
|
||||
if self.try_find_next_range(&next_sort_column)?.is_some() {
|
||||
// remaining batch still contains data that exceeds the current partition range
|
||||
@@ -687,6 +694,12 @@ impl PartSortStream {
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
|
||||
loop {
|
||||
// Early termination: if we've already produced enough rows,
|
||||
// don't poll more input - just return
|
||||
if matches!(self.limit, Some(0)) {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
// no more input, sort the buffer and return
|
||||
if self.input_complete {
|
||||
if self.buffer.is_empty() {
|
||||
@@ -701,7 +714,24 @@ impl PartSortStream {
|
||||
if let Some(evaluating_batch) = self.evaluating_batch.take()
|
||||
&& evaluating_batch.num_rows() != 0
|
||||
{
|
||||
// Check if we've already processed all partitions
|
||||
if self.cur_part_idx >= self.partition_ranges.len() {
|
||||
// All partitions processed, discard remaining data
|
||||
if self.buffer.is_empty() {
|
||||
return Poll::Ready(None);
|
||||
} else {
|
||||
let sorted_batch = self.sort_buffer()?;
|
||||
self.limit = self
|
||||
.limit
|
||||
.map(|l| l.saturating_sub(sorted_batch.num_rows()));
|
||||
return Poll::Ready(Some(Ok(sorted_batch)));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
|
||||
self.limit = self
|
||||
.limit
|
||||
.map(|l| l.saturating_sub(sorted_batch.num_rows()));
|
||||
return Poll::Ready(Some(Ok(sorted_batch)));
|
||||
} else {
|
||||
continue;
|
||||
@@ -713,6 +743,9 @@ impl PartSortStream {
|
||||
match res {
|
||||
Poll::Ready(Some(Ok(batch))) => {
|
||||
if let Some(sorted_batch) = self.split_batch(batch)? {
|
||||
self.limit = self
|
||||
.limit
|
||||
.map(|l| l.saturating_sub(sorted_batch.num_rows()));
|
||||
return Poll::Ready(Some(Ok(sorted_batch)));
|
||||
} else {
|
||||
continue;
|
||||
@@ -896,22 +929,30 @@ mod test {
|
||||
output_data.push(cur_data);
|
||||
}
|
||||
|
||||
let expected_output = output_data
|
||||
let mut limit_remains = limit;
|
||||
let mut expected_output = output_data
|
||||
.into_iter()
|
||||
.map(|a| {
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, a)]).unwrap()
|
||||
})
|
||||
.map(|rb| {
|
||||
// trim expected output with limit
|
||||
if let Some(limit) = limit
|
||||
&& rb.num_rows() > limit
|
||||
{
|
||||
rb.slice(0, limit)
|
||||
if let Some(limit) = limit_remains.as_mut() {
|
||||
let rb = rb.slice(0, (*limit).min(rb.num_rows()));
|
||||
*limit = limit.saturating_sub(rb.num_rows());
|
||||
rb
|
||||
} else {
|
||||
rb
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
while let Some(rb) = expected_output.last() {
|
||||
if rb.num_rows() == 0 {
|
||||
expected_output.pop();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.push((
|
||||
case_id,
|
||||
@@ -932,13 +973,14 @@ mod test {
|
||||
opt,
|
||||
limit,
|
||||
expected_output,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn simple_case() {
|
||||
async fn simple_cases() {
|
||||
let testcases = vec![
|
||||
(
|
||||
TimeUnit::Millisecond,
|
||||
@@ -1027,7 +1069,7 @@ mod test {
|
||||
],
|
||||
true,
|
||||
Some(2),
|
||||
vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
|
||||
vec![vec![19, 17]],
|
||||
),
|
||||
];
|
||||
|
||||
@@ -1080,6 +1122,7 @@ mod test {
|
||||
opt,
|
||||
limit,
|
||||
expected_output,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -1093,6 +1136,7 @@ mod test {
|
||||
opt: SortOptions,
|
||||
limit: Option<usize>,
|
||||
expected_output: Vec<DfRecordBatch>,
|
||||
expected_polled_rows: Option<usize>,
|
||||
) {
|
||||
for rb in &expected_output {
|
||||
if let Some(limit) = limit {
|
||||
@@ -1104,16 +1148,15 @@ mod test {
|
||||
);
|
||||
}
|
||||
}
|
||||
let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
|
||||
|
||||
let batches = batches
|
||||
.into_iter()
|
||||
.flat_map(|mut cols| {
|
||||
cols.push(DfRecordBatch::new_empty(schema.clone()));
|
||||
cols
|
||||
})
|
||||
.collect_vec();
|
||||
let mock_input = MockInputExec::new(batches, schema.clone());
|
||||
let mut data_partition = Vec::with_capacity(input_ranged_data.len());
|
||||
let mut ranges = Vec::with_capacity(input_ranged_data.len());
|
||||
for (part_range, batches) in input_ranged_data {
|
||||
data_partition.push(batches);
|
||||
ranges.push(part_range);
|
||||
}
|
||||
|
||||
let mock_input = Arc::new(MockInputExec::new(data_partition, schema.clone()));
|
||||
|
||||
let exec = PartSortExec::new(
|
||||
PhysicalSortExpr {
|
||||
@@ -1122,7 +1165,7 @@ mod test {
|
||||
},
|
||||
limit,
|
||||
vec![ranges.clone()],
|
||||
Arc::new(mock_input),
|
||||
mock_input.clone(),
|
||||
);
|
||||
|
||||
let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
|
||||
@@ -1131,12 +1174,17 @@ mod test {
|
||||
// a makeshift solution for compare large data
|
||||
if real_output != expected_output {
|
||||
let mut first_diff = 0;
|
||||
let mut is_diff_found = false;
|
||||
for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
|
||||
if lhs != rhs {
|
||||
if lhs.slice(0, rhs.num_rows()) != *rhs {
|
||||
first_diff = idx;
|
||||
is_diff_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !is_diff_found {
|
||||
return;
|
||||
}
|
||||
println!("first diff batch at {}", first_diff);
|
||||
println!(
|
||||
"ranges: {:?}",
|
||||
@@ -1175,8 +1223,14 @@ mod test {
|
||||
let buf = String::from_utf8_lossy(&buf);
|
||||
full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
|
||||
}
|
||||
|
||||
if let Some(expected_polled_rows) = expected_polled_rows {
|
||||
let input_pulled_rows = mock_input.metrics().unwrap().output_rows().unwrap();
|
||||
assert_eq!(input_pulled_rows, expected_polled_rows);
|
||||
}
|
||||
|
||||
panic!(
|
||||
"case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
|
||||
"case_{} failed (limit {limit:?}), opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
|
||||
case_id,
|
||||
opt,
|
||||
real_output.len(),
|
||||
@@ -1187,4 +1241,249 @@ mod test {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that verifies the limit is correctly applied per partition when
|
||||
/// multiple batches are received for the same partition.
|
||||
#[tokio::test]
|
||||
async fn test_limit_with_multiple_batches_per_partition() {
|
||||
let unit = TimeUnit::Millisecond;
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"ts",
|
||||
DataType::Timestamp(unit, None),
|
||||
false,
|
||||
)]));
|
||||
|
||||
// Test case: Multiple batches in a single partition with limit=3
|
||||
// Input: 3 batches with [1,2,3], [4,5,6], [7,8,9] all in partition (0,10)
|
||||
// Expected: Only top 3 values [9,8,7] for descending sort
|
||||
let input_ranged_data = vec![(
|
||||
PartitionRange {
|
||||
start: Timestamp::new(0, unit.into()),
|
||||
end: Timestamp::new(10, unit.into()),
|
||||
num_rows: 9,
|
||||
identifier: 0,
|
||||
},
|
||||
vec![
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
|
||||
.unwrap(),
|
||||
],
|
||||
)];
|
||||
|
||||
let expected_output = vec![
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8, 7])])
|
||||
.unwrap(),
|
||||
];
|
||||
|
||||
run_test(
|
||||
1000,
|
||||
input_ranged_data,
|
||||
schema.clone(),
|
||||
SortOptions {
|
||||
descending: true,
|
||||
..Default::default()
|
||||
},
|
||||
Some(3),
|
||||
expected_output,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Test case: Multiple batches across multiple partitions with limit=2
|
||||
// Partition 0: batches [10,11,12], [13,14,15] -> top 2 descending = [15,14]
|
||||
// Partition 1: batches [1,2,3], [4,5] -> top 2 descending = [5,4]
|
||||
let input_ranged_data = vec![
|
||||
(
|
||||
PartitionRange {
|
||||
start: Timestamp::new(10, unit.into()),
|
||||
end: Timestamp::new(20, unit.into()),
|
||||
num_rows: 6,
|
||||
identifier: 0,
|
||||
},
|
||||
vec![
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![10, 11, 12])],
|
||||
)
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![13, 14, 15])],
|
||||
)
|
||||
.unwrap(),
|
||||
],
|
||||
),
|
||||
(
|
||||
PartitionRange {
|
||||
start: Timestamp::new(0, unit.into()),
|
||||
end: Timestamp::new(10, unit.into()),
|
||||
num_rows: 5,
|
||||
identifier: 1,
|
||||
},
|
||||
vec![
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5])])
|
||||
.unwrap(),
|
||||
],
|
||||
),
|
||||
];
|
||||
|
||||
let expected_output = vec![
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![15, 14])]).unwrap(),
|
||||
];
|
||||
|
||||
run_test(
|
||||
1001,
|
||||
input_ranged_data,
|
||||
schema.clone(),
|
||||
SortOptions {
|
||||
descending: true,
|
||||
..Default::default()
|
||||
},
|
||||
Some(2),
|
||||
expected_output,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Test case: Ascending sort with limit
|
||||
// Partition: batches [7,8,9], [4,5,6], [1,2,3] -> top 2 ascending = [1,2]
|
||||
let input_ranged_data = vec![(
|
||||
PartitionRange {
|
||||
start: Timestamp::new(0, unit.into()),
|
||||
end: Timestamp::new(10, unit.into()),
|
||||
num_rows: 9,
|
||||
identifier: 0,
|
||||
},
|
||||
vec![
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![7, 8, 9])])
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![4, 5, 6])])
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2, 3])])
|
||||
.unwrap(),
|
||||
],
|
||||
)];
|
||||
|
||||
let expected_output = vec![
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![1, 2])]).unwrap(),
|
||||
];
|
||||
|
||||
run_test(
|
||||
1002,
|
||||
input_ranged_data,
|
||||
schema.clone(),
|
||||
SortOptions {
|
||||
descending: false,
|
||||
..Default::default()
|
||||
},
|
||||
Some(2),
|
||||
expected_output,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Test that verifies early termination behavior.
|
||||
/// Once we've produced limit * num_partitions rows, we should stop
|
||||
/// pulling from input stream.
|
||||
#[tokio::test]
|
||||
async fn test_early_termination() {
|
||||
let unit = TimeUnit::Millisecond;
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"ts",
|
||||
DataType::Timestamp(unit, None),
|
||||
false,
|
||||
)]));
|
||||
|
||||
// Create 3 partitions, each with more data than the limit
|
||||
// limit=2 per partition, so total expected output = 6 rows
|
||||
// After producing 6 rows, early termination should kick in
|
||||
let input_ranged_data = vec![
|
||||
(
|
||||
PartitionRange {
|
||||
start: Timestamp::new(0, unit.into()),
|
||||
end: Timestamp::new(10, unit.into()),
|
||||
num_rows: 10,
|
||||
identifier: 0,
|
||||
},
|
||||
vec![
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![1, 2, 3, 4, 5])],
|
||||
)
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![6, 7, 8, 9, 10])],
|
||||
)
|
||||
.unwrap(),
|
||||
],
|
||||
),
|
||||
(
|
||||
PartitionRange {
|
||||
start: Timestamp::new(10, unit.into()),
|
||||
end: Timestamp::new(20, unit.into()),
|
||||
num_rows: 10,
|
||||
identifier: 1,
|
||||
},
|
||||
vec![
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![11, 12, 13, 14, 15])],
|
||||
)
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![16, 17, 18, 19, 20])],
|
||||
)
|
||||
.unwrap(),
|
||||
],
|
||||
),
|
||||
(
|
||||
PartitionRange {
|
||||
start: Timestamp::new(20, unit.into()),
|
||||
end: Timestamp::new(30, unit.into()),
|
||||
num_rows: 10,
|
||||
identifier: 2,
|
||||
},
|
||||
vec![
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![21, 22, 23, 24, 25])],
|
||||
)
|
||||
.unwrap(),
|
||||
DfRecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![new_ts_array(unit, vec![26, 27, 28, 29, 30])],
|
||||
)
|
||||
.unwrap(),
|
||||
],
|
||||
),
|
||||
];
|
||||
|
||||
// PartSort won't reorder `PartitionRange` (it assumes it's already ordered), so it will not read other partitions.
|
||||
// This case is just to verify that early termination works as expected.
|
||||
let expected_output = vec![
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![9, 8])]).unwrap(),
|
||||
];
|
||||
|
||||
run_test(
|
||||
1003,
|
||||
input_ranged_data,
|
||||
schema.clone(),
|
||||
SortOptions {
|
||||
descending: true,
|
||||
..Default::default()
|
||||
},
|
||||
Some(2),
|
||||
expected_output,
|
||||
Some(10),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ use arrow_schema::{SchemaRef, TimeUnit};
|
||||
use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
|
||||
use datafusion::execution::{RecordBatchStream, TaskContext};
|
||||
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
|
||||
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
|
||||
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
|
||||
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
||||
use futures::Stream;
|
||||
@@ -46,13 +47,14 @@ pub fn new_ts_array(unit: TimeUnit, arr: Vec<i64>) -> ArrayRef {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MockInputExec {
|
||||
input: Vec<DfRecordBatch>,
|
||||
input: Vec<Vec<DfRecordBatch>>,
|
||||
schema: SchemaRef,
|
||||
properties: PlanProperties,
|
||||
metrics: ExecutionPlanMetricsSet,
|
||||
}
|
||||
|
||||
impl MockInputExec {
|
||||
pub fn new(input: Vec<DfRecordBatch>, schema: SchemaRef) -> Self {
|
||||
pub fn new(input: Vec<Vec<DfRecordBatch>>, schema: SchemaRef) -> Self {
|
||||
Self {
|
||||
properties: PlanProperties::new(
|
||||
EquivalenceProperties::new(schema.clone()),
|
||||
@@ -62,6 +64,7 @@ impl MockInputExec {
|
||||
),
|
||||
input,
|
||||
schema,
|
||||
metrics: ExecutionPlanMetricsSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -98,22 +101,28 @@ impl ExecutionPlan for MockInputExec {
|
||||
|
||||
fn execute(
|
||||
&self,
|
||||
_partition: usize,
|
||||
partition: usize,
|
||||
_context: Arc<TaskContext>,
|
||||
) -> datafusion_common::Result<DfSendableRecordBatchStream> {
|
||||
let stream = MockStream {
|
||||
stream: self.input.clone(),
|
||||
stream: self.input.clone().into_iter().flatten().collect(),
|
||||
schema: self.schema.clone(),
|
||||
idx: 0,
|
||||
metrics: BaselineMetrics::new(&self.metrics, partition),
|
||||
};
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
fn metrics(&self) -> Option<MetricsSet> {
|
||||
Some(self.metrics.clone_inner())
|
||||
}
|
||||
}
|
||||
|
||||
struct MockStream {
|
||||
stream: Vec<DfRecordBatch>,
|
||||
schema: SchemaRef,
|
||||
idx: usize,
|
||||
metrics: BaselineMetrics,
|
||||
}
|
||||
|
||||
impl Stream for MockStream {
|
||||
@@ -125,7 +134,7 @@ impl Stream for MockStream {
|
||||
if self.idx < self.stream.len() {
|
||||
let ret = self.stream[self.idx].clone();
|
||||
self.idx += 1;
|
||||
Poll::Ready(Some(Ok(ret)))
|
||||
self.metrics.record_poll(Poll::Ready(Some(Ok(ret))))
|
||||
} else {
|
||||
Poll::Ready(None)
|
||||
}
|
||||
|
||||
@@ -2528,7 +2528,7 @@ mod test {
|
||||
async fn run_test(&self) -> Vec<DfRecordBatch> {
|
||||
let (ranges, batches): (Vec<_>, Vec<_>) = self.input.clone().into_iter().unzip();
|
||||
|
||||
let mock_input = MockInputExec::new(batches, self.schema.clone());
|
||||
let mock_input = MockInputExec::new(vec![batches], self.schema.clone());
|
||||
|
||||
let exec = WindowedSortExec::try_new(
|
||||
self.expression.clone(),
|
||||
|
||||
Reference in New Issue
Block a user