mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-20 23:10:37 +00:00
@@ -534,43 +534,69 @@ impl PartSortStream {
|
||||
}
|
||||
|
||||
if topk {
|
||||
self.compact_topk_buffer()?;
|
||||
self.update_dynamic_filter(sort_data_type)?;
|
||||
let threshold = self.compact_topk_buffer(sort_data_type)?;
|
||||
self.update_dynamic_filter(sort_data_type, threshold)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compact_topk_buffer(&mut self) -> datafusion_common::Result<()> {
|
||||
fn compact_topk_buffer(
|
||||
&mut self,
|
||||
sort_data_type: &DataType,
|
||||
) -> datafusion_common::Result<Option<TopKThreshold>> {
|
||||
let Some(limit) = self.limit else {
|
||||
return Ok(());
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let PartSortBuffer::TopK(buffer) =
|
||||
std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
|
||||
else {
|
||||
return Ok(());
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
if limit == 0 || buffer.is_empty() {
|
||||
self.buffer = PartSortBuffer::TopK(Vec::new());
|
||||
return Ok(());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum();
|
||||
if total_rows <= limit {
|
||||
self.buffer = PartSortBuffer::TopK(buffer);
|
||||
return Ok(());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let topk = self.sort_record_batches(&buffer, Some(limit), false)?;
|
||||
let threshold = self.threshold_from_sorted_batch(&topk, sort_data_type)?;
|
||||
self.buffer = if topk.num_rows() == 0 {
|
||||
PartSortBuffer::TopK(Vec::new())
|
||||
} else {
|
||||
PartSortBuffer::TopK(vec![topk])
|
||||
};
|
||||
|
||||
Ok(())
|
||||
Ok(threshold)
|
||||
}
|
||||
|
||||
/// Derives a top-k threshold from the last row of `batch`.
///
/// Returns `Ok(None)` when `batch` has no rows. Otherwise the sort expression is
/// evaluated against `batch`, and the value at index `num_rows() - 1` is passed to
/// `threshold_helper` through `downcast_ts_array!`, dispatching on `sort_data_type`.
///
/// NOTE(review): the name implies `batch` is already sorted, so its last row would
/// be the boundary of the retained rows — confirm against the caller.
///
/// # Errors
///
/// Propagates any failure from evaluating the sort expression, and yields an
/// internal error ("Unsupported data type for sort column") when `sort_data_type`
/// reaches the macro's fallback `_` arm.
fn threshold_from_sorted_batch(
|
||||
&self,
|
||||
batch: &DfRecordBatch,
|
||||
sort_data_type: &DataType,
|
||||
) -> datafusion_common::Result<Option<TopKThreshold>> {
|
||||
// An empty batch has no last row to read a threshold from.
if batch.num_rows() == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Index of the final row; the subtraction cannot underflow because the
// zero-row case returned above.
let threshold_idx = batch.num_rows() - 1;
|
||||
// Evaluate the sort expression over the batch and keep only its value array.
let sort_column = self.expression.evaluate_to_sort_column(batch)?.values;
|
||||
// Dispatch on the sort column's data type; any type not handled by the macro
// falls into the `_` arm and is reported as an internal error.
let threshold = downcast_ts_array!(
|
||||
sort_data_type => (threshold_helper, sort_column, threshold_idx),
|
||||
_ => internal_err!(
|
||||
"Unsupported data type for sort column: {:?}",
|
||||
sort_data_type
|
||||
)?,
|
||||
);
|
||||
|
||||
Ok(Some(threshold))
|
||||
}
|
||||
|
||||
fn topk_threshold(
|
||||
@@ -581,7 +607,7 @@ impl PartSortStream {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
if self.buffer.num_rows() < limit {
|
||||
if limit == 0 || self.buffer.num_rows() < limit {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
@@ -691,13 +717,19 @@ impl PartSortStream {
|
||||
fn update_dynamic_filter(
|
||||
&mut self,
|
||||
sort_data_type: &DataType,
|
||||
threshold: Option<TopKThreshold>,
|
||||
) -> datafusion_common::Result<()> {
|
||||
let Some(filter) = &self.dynamic_filter else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
|
||||
return Ok(());
|
||||
let threshold = if let Some(threshold) = threshold {
|
||||
threshold
|
||||
} else {
|
||||
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
|
||||
return Ok(());
|
||||
};
|
||||
threshold
|
||||
};
|
||||
|
||||
if self.dynamic_filter_threshold.as_ref() == Some(&threshold) {
|
||||
@@ -722,8 +754,13 @@ impl PartSortStream {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
|
||||
return Ok(false);
|
||||
let threshold = if let Some(threshold) = &self.dynamic_filter_threshold {
|
||||
threshold.clone()
|
||||
} else {
|
||||
let Some(threshold) = self.topk_threshold(sort_data_type)? else {
|
||||
return Ok(false);
|
||||
};
|
||||
threshold
|
||||
};
|
||||
|
||||
let (_, start_idx, _) = self.range_groups[group_idx];
|
||||
@@ -1881,6 +1918,59 @@ mod test {
|
||||
assert_eq!(stream.sort_buffer().unwrap(), expected);
|
||||
}
|
||||
|
||||
/// Pushes a three-row batch into a `PartSortStream` whose limit is `Some(0)` and
/// checks that the buffer ends up holding zero rows and that
/// `dynamic_filter_threshold` is still `None`.
#[test]
|
||||
fn test_topk_limit_zero_clears_buffer_without_threshold() {
|
||||
let unit = TimeUnit::Millisecond;
|
||||
// Single non-nullable timestamp column named "ts".
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"ts",
|
||||
DataType::Timestamp(unit, None),
|
||||
false,
|
||||
)]));
|
||||
let sort_data_type = DataType::Timestamp(unit, None);
|
||||
let partition_range = PartitionRange {
|
||||
start: Timestamp::new(0, unit.into()),
|
||||
end: Timestamp::new(10, unit.into()),
|
||||
num_rows: 3,
|
||||
identifier: 0,
|
||||
};
|
||||
// The mock input is constructed from one empty inner vector; the batch under
// test is pushed into the stream directly further down.
let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
|
||||
// Descending sort on "ts" with a limit of zero.
let exec = PartSortExec::try_new(
|
||||
PhysicalSortExpr {
|
||||
expr: Arc::new(Column::new("ts", 0)),
|
||||
options: SortOptions {
|
||||
descending: true,
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
Some(0),
|
||||
vec![vec![partition_range]],
|
||||
mock_input.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
let input_stream = mock_input
|
||||
.execute(0, Arc::new(TaskContext::default()))
|
||||
.unwrap();
|
||||
// The stream is built with the same zero limit as the exec.
let mut stream = PartSortStream::new(
|
||||
Arc::new(TaskContext::default()),
|
||||
&exec,
|
||||
Some(0),
|
||||
input_stream,
|
||||
vec![partition_range],
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Push three rows (timestamps 1, 2, 3) into the stream's buffer.
stream
|
||||
.push_buffer(
|
||||
DfRecordBatch::try_new(schema, vec![new_ts_array(unit, vec![1, 2, 3])]).unwrap(),
|
||||
&sort_data_type,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// With a zero limit nothing should be retained and no threshold recorded.
assert_eq!(stream.buffer.num_rows(), 0);
|
||||
assert_eq!(stream.dynamic_filter_threshold, None);
|
||||
}
|
||||
|
||||
/// Test that verifies early termination behavior.
|
||||
/// Once we've produced limit * num_partitions rows, we should stop
|
||||
/// pulling from input stream.
|
||||
|
||||
Reference in New Issue
Block a user