Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2026-05-14 12:00:40 +00:00
maintain topk in part sort
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
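The change, reduced to its core: while a PartSortStream runs with a LIMIT, the buffered batches are compacted back down to the current top k rows each time a batch is pushed, and the weakest value still in the top k becomes the threshold of a pushed-down dynamic filter. The sketch below only illustrates that loop with plain i64 values in descending order; TopKBuffer and its methods are hypothetical names, not the types in this diff (which operate on DfRecordBatch and DynamicFilterPhysicalExpr).

```rust
/// Illustrative stand-in for the top-k buffer maintained by this commit:
/// plain values replace record batches, and the returned threshold replaces
/// the dynamic-filter predicate. All names here are hypothetical.
struct TopKBuffer {
    limit: usize,
    rows: Vec<i64>,
}

impl TopKBuffer {
    fn new(limit: usize) -> Self {
        Self { limit, rows: Vec::new() }
    }

    /// Append a "batch" and shrink back to the best `limit` rows,
    /// mirroring the push + compact step in the diff.
    fn push(&mut self, batch: &[i64]) {
        self.rows.extend_from_slice(batch);
        if self.rows.len() > self.limit {
            // Sort descending (the new test sorts `ts` descending) and keep
            // only the top `limit` rows, so the buffer stays bounded.
            self.rows.sort_unstable_by(|a, b| b.cmp(a));
            self.rows.truncate(self.limit);
        }
    }

    /// The weakest value that still made the cut; rows that cannot beat it
    /// can be pruned at the source, which is what the dynamic filter does.
    fn threshold(&self) -> Option<i64> {
        if self.limit == 0 || self.rows.len() < self.limit {
            return None;
        }
        // With descending order the threshold is the minimum retained value.
        self.rows.iter().min().copied()
    }
}

fn main() {
    let mut buf = TopKBuffer::new(3);
    for batch in [vec![1, 2, 3], vec![4, 5, 6], vec![0, 7, 8]] {
        buf.push(&batch);
        assert!(buf.rows.len() <= 3); // buffer never grows past the limit
    }
    assert_eq!(buf.rows, vec![8, 7, 6]);
    // Later input must satisfy `value > 6` to matter.
    assert_eq!(buf.threshold(), Some(6));
    println!("top-k = {:?}, threshold = {:?}", buf.rows, buf.threshold());
}
```

The actual compact_topk_buffer in the diff does the same with sort_record_batches(.., Some(limit), ..), and update_dynamic_filter republishes the threshold to the dynamic filter only when it has actually changed.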
@@ -25,7 +25,7 @@ use std::task::{Context, Poll};

 use arrow::array::{Array, ArrayRef};
 use arrow::compute::{concat, concat_batches, take_record_batch};
-use arrow_schema::SchemaRef;
+use arrow_schema::{DataType, SchemaRef, TimeUnit};
 use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
 use common_time::Timestamp;
 use datafusion::common::arrow::compute::sort_to_indices;
@@ -39,8 +39,11 @@ use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSe
 use datafusion::physical_plan::{
     DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
 };
-use datafusion_common::{DataFusionError, internal_err};
-use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
+use datafusion_common::{DataFusionError, ScalarValue, internal_err};
+use datafusion_expr::Operator;
+use datafusion_physical_expr::expressions::{
+    BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit,
+};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 use futures::Stream;
 use itertools::Itertools;
@@ -299,8 +302,10 @@ impl ExecutionPlan for PartSortExec {

 enum PartSortBuffer {
     All(Vec<DfRecordBatch>),
+    TopK(Vec<DfRecordBatch>),
 }

+#[derive(Clone, Debug, PartialEq, Eq)]
 enum TopKThreshold {
     Null,
     Value(i64),
@@ -310,12 +315,14 @@ impl PartSortBuffer {
     pub fn is_empty(&self) -> bool {
         match self {
             PartSortBuffer::All(v) => v.is_empty(),
+            PartSortBuffer::TopK(v) => v.is_empty(),
         }
     }

     pub fn num_rows(&self) -> usize {
         match self {
             PartSortBuffer::All(v) => v.iter().map(|batch| batch.num_rows()).sum(),
+            PartSortBuffer::TopK(v) => v.iter().map(|batch| batch.num_rows()).sum(),
         }
     }
 }
@@ -336,6 +343,7 @@ struct PartSortStream {
     evaluating_batch: Option<DfRecordBatch>,
     metrics: BaselineMetrics,
     dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
+    dynamic_filter_threshold: Option<TopKThreshold>,
     /// Groups of ranges by primary end: (primary_end, start_idx_inclusive, end_idx_exclusive).
     /// Ranges in the same group must be processed together before outputting results.
     range_groups: Vec<(Timestamp, usize, usize)>,
@@ -352,7 +360,11 @@ impl PartSortStream {
         partition_ranges: Vec<PartitionRange>,
         partition: usize,
     ) -> datafusion_common::Result<Self> {
-        let buffer = PartSortBuffer::All(Vec::new());
+        let buffer = if limit.is_some() {
+            PartSortBuffer::TopK(Vec::new())
+        } else {
+            PartSortBuffer::All(Vec::new())
+        };

         // Compute range groups by primary end
         let descending = sort.expression.options.descending;
@@ -373,6 +385,7 @@ impl PartSortStream {
             evaluating_batch: None,
             metrics: BaselineMetrics::new(&sort.metrics, partition),
             dynamic_filter: sort.dynamic_filter.clone(),
+            dynamic_filter_threshold: None,
             range_groups,
             cur_group_idx: 0,
         })
@@ -507,11 +520,54 @@ impl PartSortStream {
         Ok(None)
     }

-    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
+    fn push_buffer(
+        &mut self,
+        batch: DfRecordBatch,
+        sort_data_type: &DataType,
+    ) -> datafusion_common::Result<()> {
+        let topk = matches!(self.buffer, PartSortBuffer::TopK(_));
         match &mut self.buffer {
             PartSortBuffer::All(v) => v.push(batch),
+            PartSortBuffer::TopK(v) => v.push(batch),
         }
+
+        if topk {
+            self.compact_topk_buffer()?;
+            self.update_dynamic_filter(sort_data_type)?;
+        }
+
         Ok(())
     }
+
+    fn compact_topk_buffer(&mut self) -> datafusion_common::Result<()> {
+        let Some(limit) = self.limit else {
+            return Ok(());
+        };
+
+        let PartSortBuffer::TopK(buffer) =
+            std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
+        else {
+            return Ok(());
+        };
+
+        if limit == 0 || buffer.is_empty() {
+            self.buffer = PartSortBuffer::TopK(Vec::new());
+            return Ok(());
+        }
+
+        let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum();
+        if total_rows <= limit {
+            self.buffer = PartSortBuffer::TopK(buffer);
+            return Ok(());
+        }
+
+        let topk = self.sort_record_batches(&buffer, Some(limit), false)?;
+        self.buffer = if topk.num_rows() == 0 {
+            PartSortBuffer::TopK(Vec::new())
+        } else {
+            PartSortBuffer::TopK(vec![topk])
+        };
+
+        Ok(())
+    }

@@ -527,7 +583,9 @@ impl PartSortStream {
             return Ok(None);
         }

-        let PartSortBuffer::All(buffer) = &self.buffer;
+        let buffer = match &self.buffer {
+            PartSortBuffer::All(buffer) | PartSortBuffer::TopK(buffer) => buffer,
+        };
         let mut sort_columns = Vec::with_capacity(buffer.len());
         let mut opt = None;
         for batch in buffer {
@@ -567,6 +625,90 @@ impl PartSortStream {
         Ok(Some(threshold))
     }

+    fn threshold_scalar_value(
+        sort_data_type: &DataType,
+        threshold: &TopKThreshold,
+    ) -> datafusion_common::Result<ScalarValue> {
+        let value = match threshold {
+            TopKThreshold::Null => None,
+            TopKThreshold::Value(value) => Some(*value),
+        };
+
+        let scalar = match sort_data_type {
+            DataType::Timestamp(TimeUnit::Second, tz) => {
+                ScalarValue::TimestampSecond(value, tz.clone())
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, tz) => {
+                ScalarValue::TimestampMillisecond(value, tz.clone())
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, tz) => {
+                ScalarValue::TimestampMicrosecond(value, tz.clone())
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
+                ScalarValue::TimestampNanosecond(value, tz.clone())
+            }
+            _ => internal_err!(
+                "Unsupported data type for sort column: {:?}",
+                sort_data_type
+            )?,
+        };
+
+        Ok(scalar)
+    }
+
+    fn build_dynamic_filter_expr(
+        &self,
+        sort_data_type: &DataType,
+        threshold: &TopKThreshold,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        let op = if self.expression.options.descending {
+            Operator::Gt
+        } else {
+            Operator::Lt
+        };
+        let value_null = matches!(threshold, TopKThreshold::Null);
+        let value = Self::threshold_scalar_value(sort_data_type, threshold)?;
+        let comparison: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            self.expression.expr.clone(),
+            op,
+            lit(value),
+        ));
+
+        match (self.expression.options.nulls_first, value_null) {
+            (true, true) => Ok(lit(false)),
+            (true, false) => Ok(Arc::new(BinaryExpr::new(
+                is_null(self.expression.expr.clone())?,
+                Operator::Or,
+                comparison,
+            ))),
+            (false, true) => is_not_null(self.expression.expr.clone()),
+            (false, false) => Ok(comparison),
+        }
+    }
+
+    fn update_dynamic_filter(
+        &mut self,
+        sort_data_type: &DataType,
+    ) -> datafusion_common::Result<()> {
+        let Some(filter) = &self.dynamic_filter else {
+            return Ok(());
+        };
+
+        let Some(threshold) = self.topk_threshold(sort_data_type)? else {
+            return Ok(());
+        };
+
+        if self.dynamic_filter_threshold.as_ref() == Some(&threshold) {
+            return Ok(());
+        }
+
+        let predicate = self.build_dynamic_filter_expr(sort_data_type, &threshold)?;
+        filter.update(predicate)?;
+        self.dynamic_filter_threshold = Some(threshold);
+
+        Ok(())
+    }
+
     /// Returns true when all rows in the next group are guaranteed to be worse
     /// than the current top-k threshold.
     fn can_stop_before_group(
@@ -661,17 +803,20 @@ impl PartSortStream {
     fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
         match &mut self.buffer {
             PartSortBuffer::All(_) => self.sort_all_buffer(),
+            PartSortBuffer::TopK(_) => self.sort_topk_buffer(),
         }
     }

-    /// Internal method for sorting `All` buffer (without limit).
-    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
-        let PartSortBuffer::All(buffer) =
-            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()));
-
+    fn sort_record_batches(
+        &mut self,
+        buffer: &[DfRecordBatch],
+        limit: Option<usize>,
+        check_range: bool,
+    ) -> datafusion_common::Result<DfRecordBatch> {
         if buffer.is_empty() {
             return Ok(DfRecordBatch::new_empty(self.schema.clone()));
         }

         let mut sort_columns = Vec::with_capacity(buffer.len());
         let mut opt = None;
         for batch in buffer.iter() {
@@ -688,7 +833,7 @@ impl PartSortStream {
             )
         })?;

-        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
+        let indices = sort_to_indices(&sort_column, opt, limit).map_err(|e| {
             DataFusionError::ArrowError(
                 Box::new(e),
                 Some(format!("Fail to sort to indices at {}", location!())),
@@ -698,7 +843,7 @@ impl PartSortStream {
             return Ok(DfRecordBatch::new_empty(self.schema.clone()));
         }

-        if self.limit.is_none() {
+        if check_range {
             self.check_in_range(
                 &sort_column,
                 (
@@ -722,7 +867,7 @@ impl PartSortStream {
         let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
         self.reservation.try_grow(total_mem * 2)?;

-        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
+        let full_input = concat_batches(&self.schema, buffer).map_err(|e| {
             DataFusionError::ArrowError(
                 Box::new(e),
                 Some(format!(
@@ -748,6 +893,27 @@ impl PartSortStream {
         Ok(sorted)
     }

+    /// Internal method for sorting `All` buffer (without limit).
+    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
+        let PartSortBuffer::All(buffer) =
+            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
+        else {
+            unreachable!()
+        };
+
+        self.sort_record_batches(&buffer, self.limit, self.limit.is_none())
+    }
+
+    fn sort_topk_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
+        let PartSortBuffer::TopK(buffer) =
+            std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
+        else {
+            unreachable!()
+        };
+
+        self.sort_record_batches(&buffer, self.limit, false)
+    }
+
     /// Sorts current buffer and returns `None` when there is nothing to emit.
     fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
         if self.buffer.is_empty() {
@@ -814,7 +980,7 @@ impl PartSortStream {

         let next_range_idx = self.try_find_next_range(&sort_column)?;
         let Some(idx) = next_range_idx else {
-            self.push_buffer(batch)?;
+            self.push_buffer(batch, sort_column.data_type())?;
             // keep polling input for next batch
             return Ok(());
         };
@@ -822,7 +988,7 @@ impl PartSortStream {
         let this_range = batch.slice(0, idx);
         let remaining_range = batch.slice(idx, batch.num_rows() - idx);
         if this_range.num_rows() != 0 {
-            self.push_buffer(this_range)?;
+            self.push_buffer(this_range, sort_column.data_type())?;
         }

         // Step to next proper PartitionRange
@@ -855,7 +1021,7 @@ impl PartSortStream {
         } else if remaining_range.num_rows() != 0 {
             // remaining batch is within the current partition range
             // push to the buffer and continue polling
-            self.push_buffer(remaining_range)?;
+            self.push_buffer(remaining_range, sort_column.data_type())?;
         }

         Ok(())
@@ -877,7 +1043,7 @@ impl PartSortStream {

         let next_range_idx = self.try_find_next_range(&sort_column)?;
         let Some(idx) = next_range_idx else {
-            self.push_buffer(batch)?;
+            self.push_buffer(batch, sort_column.data_type())?;
             // keep polling input for next batch
             return Ok(None);
         };
@@ -885,7 +1051,7 @@ impl PartSortStream {
         let this_range = batch.slice(0, idx);
         let remaining_range = batch.slice(idx, batch.num_rows() - idx);
         if this_range.num_rows() != 0 {
-            self.push_buffer(this_range)?;
+            self.push_buffer(this_range, sort_column.data_type())?;
         }

         // Step to next proper PartitionRange
@@ -910,7 +1076,7 @@ impl PartSortStream {
         } else {
             // remaining batch is within the current partition range
             if remaining_range.num_rows() != 0 {
-                self.push_buffer(remaining_range)?;
+                self.push_buffer(remaining_range, sort_column.data_type())?;
             }
         }
         // Return None to continue collecting within the same group
@@ -930,7 +1096,7 @@ impl PartSortStream {
             // remaining batch is within the current partition range
             // push to the buffer and continue polling
             if remaining_range.num_rows() != 0 {
-                self.push_buffer(remaining_range)?;
+                self.push_buffer(remaining_range, sort_column.data_type())?;
             }
         }

@@ -1015,8 +1181,8 @@ mod test {
     use std::sync::Arc;

     use arrow::array::{
-        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
-        TimestampSecondArray,
+        BooleanArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+        TimestampNanosecondArray, TimestampSecondArray,
     };
     use arrow::json::ArrayWriter;
     use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
@@ -1637,6 +1803,77 @@ mod test {
         .await;
     }

+    #[test]
+    fn test_topk_buffer_is_bounded_and_updates_dynamic_filter() {
+        let unit = TimeUnit::Millisecond;
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "ts",
+            DataType::Timestamp(unit, None),
+            false,
+        )]));
+        let sort_data_type = DataType::Timestamp(unit, None);
+        let partition_range = PartitionRange {
+            start: Timestamp::new(0, unit.into()),
+            end: Timestamp::new(10, unit.into()),
+            num_rows: 9,
+            identifier: 0,
+        };
+        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
+        let exec = PartSortExec::try_new(
+            PhysicalSortExpr {
+                expr: Arc::new(Column::new("ts", 0)),
+                options: SortOptions {
+                    descending: true,
+                    ..Default::default()
+                },
+            },
+            Some(3),
+            vec![vec![partition_range]],
+            mock_input.clone(),
+        )
+        .unwrap();
+        let input_stream = mock_input
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+        let mut stream = PartSortStream::new(
+            Arc::new(TaskContext::default()),
+            &exec,
+            Some(3),
+            input_stream,
+            vec![partition_range],
+            0,
+        )
+        .unwrap();
+
+        for batch in [vec![1, 2, 3], vec![4, 5, 6], vec![0, 7, 8]] {
+            stream
+                .push_buffer(
+                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, batch)])
+                        .unwrap(),
+                    &sort_data_type,
+                )
+                .unwrap();
+            assert_eq!(stream.buffer.num_rows(), 3);
+        }
+
+        let dynamic_filter = stream.dynamic_filter.as_ref().unwrap().clone();
+        let probe = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7])])
+            .unwrap();
+        let predicate = dynamic_filter.current().unwrap();
+        let result = predicate
+            .evaluate(&probe)
+            .unwrap()
+            .into_array(probe.num_rows())
+            .unwrap();
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert_eq!(result, &BooleanArray::from(vec![false, false, true]));
+
+        let expected =
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![8, 7, 6])])
+                .unwrap();
+        assert_eq!(stream.sort_buffer().unwrap(), expected);
+    }
+
     /// Test that verifies early termination behavior.
     /// Once we've produced limit * num_partitions rows, we should stop
     /// pulling from input stream.
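One subtlety in the new build_dynamic_filter_expr is NULL handling: whether a NULL sort key can still enter the top k depends on nulls_first, and the current threshold itself may be NULL. The sketch below restates the four cases over Option<i64> in plain Rust; the function name and the scalar types are illustrative only, while the diff expresses the same logic as a PhysicalExpr built from is_null / is_not_null / BinaryExpr and lit.

```rust
/// Plain-Rust restatement of the four cases in `build_dynamic_filter_expr`:
/// a row is worth keeping only if it could still enter the current top-k.
/// `None` models SQL NULL; this function name is illustrative, not from the diff.
fn row_may_improve(
    row: Option<i64>,
    threshold: Option<i64>, // the k-th best value seen so far, possibly NULL
    descending: bool,
    nulls_first: bool,
) -> bool {
    // "Better than the threshold" for non-null values: `>` when sorting
    // descending, `<` when ascending (Operator::Gt / Operator::Lt in the diff).
    let beats = |v: i64, t: i64| if descending { v > t } else { v < t };

    match (nulls_first, threshold) {
        // NULLs sort first and the threshold itself is NULL: nothing can improve.
        (true, None) => false,
        // NULLs sort first: a NULL row always qualifies, otherwise compare values.
        (true, Some(t)) => row.is_none() || row.map_or(false, |v| beats(v, t)),
        // NULLs sort last and the threshold is NULL: any non-null row qualifies.
        (false, None) => row.is_some(),
        // NULLs sort last: only non-null rows that beat the threshold qualify.
        (false, Some(t)) => row.map_or(false, |v| beats(v, t)),
    }
}

fn main() {
    // Mirrors the probe in the new test: limit 3, descending, threshold 6,
    // probe values [5, 6, 7] -> [false, false, true].
    let results: Vec<bool> = [5, 6, 7]
        .into_iter()
        .map(|v| row_may_improve(Some(v), Some(6), true, false))
        .collect();
    assert_eq!(results, vec![false, false, true]);
    println!("{:?}", results);
}
```

The final assertion mirrors the probe in the new test: with limit 3, descending order, and threshold 6, only the value 7 can still displace a buffered row.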