fix limit 0

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2026-05-14 11:29:58 +08:00
parent bbe2b1c11a
commit dcf97128be

View File

@@ -534,43 +534,69 @@ impl PartSortStream {
}
if topk {
self.compact_topk_buffer()?;
self.update_dynamic_filter(sort_data_type)?;
let threshold = self.compact_topk_buffer(sort_data_type)?;
self.update_dynamic_filter(sort_data_type, threshold)?;
}
Ok(())
}
// NOTE(review): this span is a diff rendering without +/- markers. Each superseded line
// (the one-line signature, every `return Ok(());`, the trailing `Ok(())`) is immediately
// followed by the line that replaces it.
/// Compacts the buffered TopK batches down to at most `limit` rows.
///
/// In the replacement signature the method takes `sort_data_type` and returns
/// `Ok(Some(threshold))` only when a compaction actually happened; every early exit
/// returns `Ok(None)`.
fn compact_topk_buffer(&mut self) -> datafusion_common::Result<()> {
fn compact_topk_buffer(
&mut self,
sort_data_type: &DataType,
) -> datafusion_common::Result<Option<TopKThreshold>> {
// Without a limit there is nothing to compact.
let Some(limit) = self.limit else {
return Ok(());
return Ok(None);
};
// Take ownership of the buffered batches, leaving an empty TopK buffer in their place.
// NOTE(review): if the buffer is not the `TopK` variant, the value taken out here is
// dropped and `self.buffer` stays an empty TopK buffer — confirm callers only invoke
// this method in TopK mode.
let PartSortBuffer::TopK(buffer) =
std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
else {
return Ok(());
return Ok(None);
};
// With `limit == 0` the taken batches are discarded: the buffer is left empty.
if limit == 0 || buffer.is_empty() {
self.buffer = PartSortBuffer::TopK(Vec::new());
return Ok(());
return Ok(None);
}
let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum();
// Not more rows than the limit: put the batches back unchanged.
if total_rows <= limit {
self.buffer = PartSortBuffer::TopK(buffer);
return Ok(());
return Ok(None);
}
// Sort the buffered batches with `limit` applied; the meaning of the trailing `false`
// argument is not visible from here — TODO confirm against `sort_record_batches`.
let topk = self.sort_record_batches(&buffer, Some(limit), false)?;
// Derive the threshold from the sorted result before it is moved into the buffer.
let threshold = self.threshold_from_sorted_batch(&topk, sort_data_type)?;
// The sorted batch becomes the only buffered batch (or none when it has no rows).
self.buffer = if topk.num_rows() == 0 {
PartSortBuffer::TopK(Vec::new())
} else {
PartSortBuffer::TopK(vec![topk])
};
Ok(())
Ok(threshold)
}
/// Builds a [`TopKThreshold`] from the last row of an already sorted `batch`.
///
/// The sort expression is evaluated against `batch`, and the resulting column together
/// with the index of the final row is handed to `threshold_helper` through
/// `downcast_ts_array!`, dispatching on `sort_data_type`.
///
/// Returns `Ok(None)` when `batch` has no rows.
///
/// # Errors
///
/// Fails if the sort expression cannot be evaluated, or with an internal error when
/// `sort_data_type` is not handled by `downcast_ts_array!`.
fn threshold_from_sorted_batch(
    &self,
    batch: &DfRecordBatch,
    sort_data_type: &DataType,
) -> datafusion_common::Result<Option<TopKThreshold>> {
    // `checked_sub` yields `None` exactly when the batch is empty, i.e. has no last row.
    let Some(last_row) = batch.num_rows().checked_sub(1) else {
        return Ok(None);
    };

    let sort_values = self.expression.evaluate_to_sort_column(batch)?.values;
    let boundary = downcast_ts_array!(
        sort_data_type => (threshold_helper, sort_values, last_row),
        _ => internal_err!(
            "Unsupported data type for sort column: {:?}",
            sort_data_type
        )?,
    );

    Ok(Some(boundary))
}
fn topk_threshold(
@@ -581,7 +607,7 @@ impl PartSortStream {
return Ok(None);
};
if self.buffer.num_rows() < limit {
if limit == 0 || self.buffer.num_rows() < limit {
return Ok(None);
}
@@ -691,13 +717,19 @@ impl PartSortStream {
fn update_dynamic_filter(
&mut self,
sort_data_type: &DataType,
threshold: Option<TopKThreshold>,
) -> datafusion_common::Result<()> {
let Some(filter) = &self.dynamic_filter else {
return Ok(());
};
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
return Ok(());
let threshold = if let Some(threshold) = threshold {
threshold
} else {
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
return Ok(());
};
threshold
};
if self.dynamic_filter_threshold.as_ref() == Some(&threshold) {
@@ -722,8 +754,13 @@ impl PartSortStream {
return Ok(false);
}
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
return Ok(false);
let threshold = if let Some(threshold) = &self.dynamic_filter_threshold {
threshold.clone()
} else {
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
return Ok(false);
};
threshold
};
let (_, start_idx, _) = self.range_groups[group_idx];
@@ -1881,6 +1918,59 @@ mod test {
assert_eq!(stream.sort_buffer().unwrap(), expected);
}
/// With a limit of zero, pushing a batch must leave the stream's buffer empty and must
/// not record a dynamic filter threshold.
#[test]
fn test_topk_limit_zero_clears_buffer_without_threshold() {
    let unit = TimeUnit::Millisecond;
    let sort_data_type = DataType::Timestamp(unit, None);
    let limit = Some(0);

    // Single non-nullable timestamp column named `ts`.
    let ts_field = Field::new("ts", sort_data_type.clone(), false);
    let schema = Arc::new(Schema::new(vec![ts_field]));

    let range = PartitionRange {
        start: Timestamp::new(0, unit.into()),
        end: Timestamp::new(10, unit.into()),
        num_rows: 3,
        identifier: 0,
    };

    // The mock input yields nothing; rows are pushed into the stream by hand below.
    let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
    let sort_expr = PhysicalSortExpr {
        expr: Arc::new(Column::new("ts", 0)),
        options: SortOptions {
            descending: true,
            ..Default::default()
        },
    };
    let exec =
        PartSortExec::try_new(sort_expr, limit, vec![vec![range]], mock_input.clone()).unwrap();

    let input_stream = mock_input
        .execute(0, Arc::new(TaskContext::default()))
        .unwrap();
    let mut stream = PartSortStream::new(
        Arc::new(TaskContext::default()),
        &exec,
        limit,
        input_stream,
        vec![range],
        0,
    )
    .unwrap();

    let batch = DfRecordBatch::try_new(schema, vec![new_ts_array(unit, vec![1, 2, 3])]).unwrap();
    stream.push_buffer(batch, &sort_data_type).unwrap();

    assert_eq!(stream.buffer.num_rows(), 0);
    assert_eq!(stream.dynamic_filter_threshold, None);
}
/// Test that verifies early termination behavior.
/// Once we've produced limit * num_partitions rows, we should stop
/// pulling from input stream.