maintain topk in part sort

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2026-05-13 14:41:16 +08:00
parent 3f19030aaa
commit 372d707839

View File

@@ -25,7 +25,7 @@ use std::task::{Context, Poll};
use arrow::array::{Array, ArrayRef};
use arrow::compute::{concat, concat_batches, take_record_batch};
use arrow_schema::SchemaRef;
use arrow_schema::{DataType, SchemaRef, TimeUnit};
use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
use common_time::Timestamp;
use datafusion::common::arrow::compute::sort_to_indices;
@@ -39,8 +39,11 @@ use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSe
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
};
use datafusion_common::{DataFusionError, internal_err};
use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
use datafusion_common::{DataFusionError, ScalarValue, internal_err};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{
BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit,
};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use futures::Stream;
use itertools::Itertools;
@@ -299,8 +302,10 @@ impl ExecutionPlan for PartSortExec {
/// Buffer of input batches accumulated before a sort is performed.
enum PartSortBuffer {
    /// Keeps every pushed batch; sorted in full when the buffer is drained.
    All(Vec<DfRecordBatch>),
    /// Used when a `limit` is set: compacted after pushes so only the current
    /// top-k rows are retained.
    TopK(Vec<DfRecordBatch>),
}
#[derive(Clone, Debug, PartialEq, Eq)]
enum TopKThreshold {
Null,
Value(i64),
@@ -310,12 +315,14 @@ impl PartSortBuffer {
/// Returns `true` when no batches are buffered, regardless of mode.
pub fn is_empty(&self) -> bool {
    match self {
        PartSortBuffer::All(batches) | PartSortBuffer::TopK(batches) => batches.is_empty(),
    }
}
/// Total number of rows across all buffered batches.
pub fn num_rows(&self) -> usize {
    match self {
        PartSortBuffer::All(batches) | PartSortBuffer::TopK(batches) => {
            batches.iter().map(|batch| batch.num_rows()).sum()
        }
    }
}
}
@@ -336,6 +343,7 @@ struct PartSortStream {
evaluating_batch: Option<DfRecordBatch>,
metrics: BaselineMetrics,
dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
dynamic_filter_threshold: Option<TopKThreshold>,
/// Groups of ranges by primary end: (primary_end, start_idx_inclusive, end_idx_exclusive).
/// Ranges in the same group must be processed together before outputting results.
range_groups: Vec<(Timestamp, usize, usize)>,
@@ -352,7 +360,11 @@ impl PartSortStream {
partition_ranges: Vec<PartitionRange>,
partition: usize,
) -> datafusion_common::Result<Self> {
let buffer = PartSortBuffer::All(Vec::new());
let buffer = if limit.is_some() {
PartSortBuffer::TopK(Vec::new())
} else {
PartSortBuffer::All(Vec::new())
};
// Compute range groups by primary end
let descending = sort.expression.options.descending;
@@ -373,6 +385,7 @@ impl PartSortStream {
evaluating_batch: None,
metrics: BaselineMetrics::new(&sort.metrics, partition),
dynamic_filter: sort.dynamic_filter.clone(),
dynamic_filter_threshold: None,
range_groups,
cur_group_idx: 0,
})
@@ -507,11 +520,54 @@ impl PartSortStream {
Ok(None)
}
fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
/// Appends `batch` to the buffer.
///
/// In top-k mode the buffer is compacted immediately after the push and the
/// pushdown dynamic filter is refreshed with the new threshold.
fn push_buffer(
    &mut self,
    batch: DfRecordBatch,
    sort_data_type: &DataType,
) -> datafusion_common::Result<()> {
    // Record the mode first: the compaction below needs `&mut self`, which
    // cannot overlap the mutable borrow of `self.buffer`.
    let is_topk = matches!(self.buffer, PartSortBuffer::TopK(_));
    let (PartSortBuffer::All(batches) | PartSortBuffer::TopK(batches)) = &mut self.buffer;
    batches.push(batch);
    if is_topk {
        self.compact_topk_buffer()?;
        self.update_dynamic_filter(sort_data_type)?;
    }
    Ok(())
}
/// Shrinks the `TopK` buffer to at most `limit` rows by sorting the buffered
/// batches and keeping only the best `limit` rows.
///
/// No-op when no limit is configured. Fix over the previous version: if the
/// buffer is not in `TopK` mode, the original `mem::replace` + `let … else`
/// pattern silently replaced an `All` buffer with an empty `TopK` one,
/// dropping its batches; we now put the non-`TopK` variant back untouched.
fn compact_topk_buffer(&mut self) -> datafusion_common::Result<()> {
    let Some(limit) = self.limit else {
        return Ok(());
    };
    // Take the batches out only when the buffer really is `TopK`; restore
    // any other variant instead of discarding it.
    let buffer = match std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new())) {
        PartSortBuffer::TopK(batches) => batches,
        other => {
            self.buffer = other;
            return Ok(());
        }
    };
    if limit == 0 || buffer.is_empty() {
        // Nothing can be kept; the fresh empty `TopK` buffer stays in place.
        return Ok(());
    }
    let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum();
    if total_rows <= limit {
        // Already within budget; put the batches back unchanged.
        self.buffer = PartSortBuffer::TopK(buffer);
        return Ok(());
    }
    // Sort without range checking and truncate to the top `limit` rows.
    let topk = self.sort_record_batches(&buffer, Some(limit), false)?;
    self.buffer = if topk.num_rows() == 0 {
        PartSortBuffer::TopK(Vec::new())
    } else {
        PartSortBuffer::TopK(vec![topk])
    };
    Ok(())
}
@@ -527,7 +583,9 @@ impl PartSortStream {
return Ok(None);
}
let PartSortBuffer::All(buffer) = &self.buffer;
let buffer = match &self.buffer {
PartSortBuffer::All(buffer) | PartSortBuffer::TopK(buffer) => buffer,
};
let mut sort_columns = Vec::with_capacity(buffer.len());
let mut opt = None;
for batch in buffer {
@@ -567,6 +625,90 @@ impl PartSortStream {
Ok(Some(threshold))
}
/// Converts a top-k threshold into a `ScalarValue` matching the sort
/// column's timestamp type.
///
/// # Errors
/// Returns an internal error for any non-timestamp `sort_data_type`.
fn threshold_scalar_value(
    sort_data_type: &DataType,
    threshold: &TopKThreshold,
) -> datafusion_common::Result<ScalarValue> {
    let value = if let TopKThreshold::Value(v) = threshold {
        Some(*v)
    } else {
        None
    };
    let DataType::Timestamp(unit, tz) = sort_data_type else {
        return internal_err!(
            "Unsupported data type for sort column: {:?}",
            sort_data_type
        );
    };
    let scalar = match unit {
        TimeUnit::Second => ScalarValue::TimestampSecond(value, tz.clone()),
        TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz.clone()),
        TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz.clone()),
        TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz.clone()),
    };
    Ok(scalar)
}
/// Builds the pushdown predicate for the current top-k threshold.
///
/// A row can still enter the top-k only if it sorts strictly better than the
/// current k-th value: `>` for descending order, `<` for ascending.
fn build_dynamic_filter_expr(
    &self,
    sort_data_type: &DataType,
    threshold: &TopKThreshold,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
    let op = if self.expression.options.descending {
        Operator::Gt
    } else {
        Operator::Lt
    };
    let value_null = matches!(threshold, TopKThreshold::Null);
    let value = Self::threshold_scalar_value(sort_data_type, threshold)?;
    // `sort_col <op> threshold`; combined with null handling below.
    let comparison: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        self.expression.expr.clone(),
        op,
        lit(value),
    ));
    // The null ordering decides how NULL rows interact with the threshold.
    match (self.expression.options.nulls_first, value_null) {
        // Nulls sort first and the k-th value is already NULL: nothing can
        // sort better, so no row passes.
        (true, true) => Ok(lit(false)),
        // Nulls sort first with a value threshold: NULLs always beat a
        // value, non-null rows must win the comparison.
        (true, false) => Ok(Arc::new(BinaryExpr::new(
            is_null(self.expression.expr.clone())?,
            Operator::Or,
            comparison,
        ))),
        // Nulls sort last and the k-th value is NULL: any non-null row
        // beats it.
        (false, true) => is_not_null(self.expression.expr.clone()),
        // Nulls sort last with a value threshold: NULLs can never beat a
        // value, so only the comparison matters.
        (false, false) => Ok(comparison),
    }
}
/// Refreshes the pushdown dynamic filter when the top-k threshold changes.
///
/// Does nothing when no dynamic filter is attached, when no threshold is
/// available yet, or when the threshold is unchanged since the last update.
fn update_dynamic_filter(
    &mut self,
    sort_data_type: &DataType,
) -> datafusion_common::Result<()> {
    if self.dynamic_filter.is_none() {
        return Ok(());
    }
    let Some(threshold) = self.topk_threshold(sort_data_type)? else {
        return Ok(());
    };
    // Avoid a redundant filter update when the threshold has not moved.
    if self.dynamic_filter_threshold.as_ref() == Some(&threshold) {
        return Ok(());
    }
    let predicate = self.build_dynamic_filter_expr(sort_data_type, &threshold)?;
    if let Some(filter) = &self.dynamic_filter {
        filter.update(predicate)?;
    }
    self.dynamic_filter_threshold = Some(threshold);
    Ok(())
}
/// Returns true when all rows in the next group are guaranteed to be worse
/// than the current top-k threshold.
fn can_stop_before_group(
@@ -661,17 +803,20 @@ impl PartSortStream {
/// Drains and sorts the current buffer, dispatching on its mode.
fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
    if matches!(self.buffer, PartSortBuffer::All(_)) {
        self.sort_all_buffer()
    } else {
        self.sort_topk_buffer()
    }
}
/// Internal method for sorting `All` buffer (without limit).
fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
let PartSortBuffer::All(buffer) =
std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()));
fn sort_record_batches(
&mut self,
buffer: &[DfRecordBatch],
limit: Option<usize>,
check_range: bool,
) -> datafusion_common::Result<DfRecordBatch> {
if buffer.is_empty() {
return Ok(DfRecordBatch::new_empty(self.schema.clone()));
}
let mut sort_columns = Vec::with_capacity(buffer.len());
let mut opt = None;
for batch in buffer.iter() {
@@ -688,7 +833,7 @@ impl PartSortStream {
)
})?;
let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
let indices = sort_to_indices(&sort_column, opt, limit).map_err(|e| {
DataFusionError::ArrowError(
Box::new(e),
Some(format!("Fail to sort to indices at {}", location!())),
@@ -698,7 +843,7 @@ impl PartSortStream {
return Ok(DfRecordBatch::new_empty(self.schema.clone()));
}
if self.limit.is_none() {
if check_range {
self.check_in_range(
&sort_column,
(
@@ -722,7 +867,7 @@ impl PartSortStream {
let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
self.reservation.try_grow(total_mem * 2)?;
let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
let full_input = concat_batches(&self.schema, buffer).map_err(|e| {
DataFusionError::ArrowError(
Box::new(e),
Some(format!(
@@ -748,6 +893,27 @@ impl PartSortStream {
Ok(sorted)
}
/// Internal method for sorting `All` buffer (without limit).
fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
let PartSortBuffer::All(buffer) =
std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
else {
unreachable!()
};
self.sort_record_batches(&buffer, self.limit, self.limit.is_none())
}
/// Internal method for sorting a `TopK` buffer down to `self.limit` rows.
fn sort_topk_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
    let batches = match std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new())) {
        PartSortBuffer::TopK(batches) => batches,
        // `sort_buffer` only dispatches here for the `TopK` variant.
        PartSortBuffer::All(_) => unreachable!(),
    };
    self.sort_record_batches(&batches, self.limit, false)
}
/// Sorts current buffer and returns `None` when there is nothing to emit.
fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
if self.buffer.is_empty() {
@@ -814,7 +980,7 @@ impl PartSortStream {
let next_range_idx = self.try_find_next_range(&sort_column)?;
let Some(idx) = next_range_idx else {
self.push_buffer(batch)?;
self.push_buffer(batch, sort_column.data_type())?;
// keep polling input for next batch
return Ok(());
};
@@ -822,7 +988,7 @@ impl PartSortStream {
let this_range = batch.slice(0, idx);
let remaining_range = batch.slice(idx, batch.num_rows() - idx);
if this_range.num_rows() != 0 {
self.push_buffer(this_range)?;
self.push_buffer(this_range, sort_column.data_type())?;
}
// Step to next proper PartitionRange
@@ -855,7 +1021,7 @@ impl PartSortStream {
} else if remaining_range.num_rows() != 0 {
// remaining batch is within the current partition range
// push to the buffer and continue polling
self.push_buffer(remaining_range)?;
self.push_buffer(remaining_range, sort_column.data_type())?;
}
Ok(())
@@ -877,7 +1043,7 @@ impl PartSortStream {
let next_range_idx = self.try_find_next_range(&sort_column)?;
let Some(idx) = next_range_idx else {
self.push_buffer(batch)?;
self.push_buffer(batch, sort_column.data_type())?;
// keep polling input for next batch
return Ok(None);
};
@@ -885,7 +1051,7 @@ impl PartSortStream {
let this_range = batch.slice(0, idx);
let remaining_range = batch.slice(idx, batch.num_rows() - idx);
if this_range.num_rows() != 0 {
self.push_buffer(this_range)?;
self.push_buffer(this_range, sort_column.data_type())?;
}
// Step to next proper PartitionRange
@@ -910,7 +1076,7 @@ impl PartSortStream {
} else {
// remaining batch is within the current partition range
if remaining_range.num_rows() != 0 {
self.push_buffer(remaining_range)?;
self.push_buffer(remaining_range, sort_column.data_type())?;
}
}
// Return None to continue collecting within the same group
@@ -930,7 +1096,7 @@ impl PartSortStream {
// remaining batch is within the current partition range
// push to the buffer and continue polling
if remaining_range.num_rows() != 0 {
self.push_buffer(remaining_range)?;
self.push_buffer(remaining_range, sort_column.data_type())?;
}
}
@@ -1015,8 +1181,8 @@ mod test {
use std::sync::Arc;
use arrow::array::{
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
BooleanArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
};
use arrow::json::ArrayWriter;
use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
@@ -1637,6 +1803,77 @@ mod test {
.await;
}
// Verifies that in top-k mode (limit = 3, descending on "ts"):
// 1. the buffer is compacted back to `limit` rows after every push, and
// 2. the dynamic filter reflects the current 3rd-best value, and
// 3. draining the buffer yields the overall top-3 in sort order.
#[test]
fn test_topk_buffer_is_bounded_and_updates_dynamic_filter() {
    let unit = TimeUnit::Millisecond;
    let schema = Arc::new(Schema::new(vec![Field::new(
        "ts",
        DataType::Timestamp(unit, None),
        false,
    )]));
    let sort_data_type = DataType::Timestamp(unit, None);
    let partition_range = PartitionRange {
        start: Timestamp::new(0, unit.into()),
        end: Timestamp::new(10, unit.into()),
        num_rows: 9,
        identifier: 0,
    };
    let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
    // Top-3, descending sort on the "ts" column.
    let exec = PartSortExec::try_new(
        PhysicalSortExpr {
            expr: Arc::new(Column::new("ts", 0)),
            options: SortOptions {
                descending: true,
                ..Default::default()
            },
        },
        Some(3),
        vec![vec![partition_range]],
        mock_input.clone(),
    )
    .unwrap();
    let input_stream = mock_input
        .execute(0, Arc::new(TaskContext::default()))
        .unwrap();
    let mut stream = PartSortStream::new(
        Arc::new(TaskContext::default()),
        &exec,
        Some(3),
        input_stream,
        vec![partition_range],
        0,
    )
    .unwrap();
    // Push three 3-row batches (values 0..=8 overall); after each push the
    // TopK buffer must be compacted back down to the limit of 3 rows.
    for batch in [vec![1, 2, 3], vec![4, 5, 6], vec![0, 7, 8]] {
        stream
            .push_buffer(
                DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, batch)])
                    .unwrap(),
                &sort_data_type,
            )
            .unwrap();
        assert_eq!(stream.buffer.num_rows(), 3);
    }
    // The filter threshold is now the 3rd-best value (6, descending), so only
    // rows strictly greater than 6 may still pass.
    let dynamic_filter = stream.dynamic_filter.as_ref().unwrap().clone();
    let probe = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7])])
        .unwrap();
    let predicate = dynamic_filter.current().unwrap();
    let result = predicate
        .evaluate(&probe)
        .unwrap()
        .into_array(probe.num_rows())
        .unwrap();
    let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
    assert_eq!(result, &BooleanArray::from(vec![false, false, true]));
    // Draining the buffer yields the overall top-3 in descending order.
    let expected =
        DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![8, 7, 6])])
            .unwrap();
    assert_eq!(stream.sort_buffer().unwrap(), expected);
}
/// Test that verifies early termination behavior.
/// Once we've produced limit * num_partitions rows, we should stop
/// pulling from input stream.