diff --git a/src/query/src/part_sort.rs b/src/query/src/part_sort.rs
index 72dd3d5d4c..5309b3aabd 100644
--- a/src/query/src/part_sort.rs
+++ b/src/query/src/part_sort.rs
@@ -25,7 +25,7 @@ use std::task::{Context, Poll};
 
 use arrow::array::{Array, ArrayRef};
 use arrow::compute::{concat, concat_batches, take_record_batch};
-use arrow_schema::SchemaRef;
+use arrow_schema::{DataType, SchemaRef, TimeUnit};
 use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
 use common_time::Timestamp;
 use datafusion::common::arrow::compute::sort_to_indices;
@@ -39,8 +39,11 @@ use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
 use datafusion::physical_plan::{
     DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
 };
-use datafusion_common::{DataFusionError, internal_err};
-use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
+use datafusion_common::{DataFusionError, ScalarValue, internal_err};
+use datafusion_expr::Operator;
+use datafusion_physical_expr::expressions::{
+    BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit,
+};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 use futures::Stream;
 use itertools::Itertools;
@@ -299,8 +302,10 @@ impl ExecutionPlan for PartSortExec {
 enum PartSortBuffer {
     All(Vec<DfRecordBatch>),
+    TopK(Vec<DfRecordBatch>),
 }
 
+#[derive(Clone, Debug, PartialEq, Eq)]
 enum TopKThreshold {
     Null,
     Value(i64),
 }
@@ -310,12 +315,14 @@ impl PartSortBuffer {
     pub fn is_empty(&self) -> bool {
         match self {
             PartSortBuffer::All(v) => v.is_empty(),
+            PartSortBuffer::TopK(v) => v.is_empty(),
         }
     }
 
     pub fn num_rows(&self) -> usize {
         match self {
             PartSortBuffer::All(v) => v.iter().map(|batch| batch.num_rows()).sum(),
+            PartSortBuffer::TopK(v) => v.iter().map(|batch| batch.num_rows()).sum(),
         }
     }
 }
@@ -336,6 +343,7 @@ struct PartSortStream {
     evaluating_batch: Option<DfRecordBatch>,
     metrics: BaselineMetrics,
     dynamic_filter: Option<Arc<DynamicFilterPhysicalExpr>>,
+    dynamic_filter_threshold: Option<TopKThreshold>,
     /// Groups of ranges by primary end: (primary_end, start_idx_inclusive, end_idx_exclusive).
     /// Ranges in the same group must be processed together before outputting results.
     range_groups: Vec<(Timestamp, usize, usize)>,
@@ -352,7 +360,11 @@ impl PartSortStream {
         partition_ranges: Vec<PartitionRange>,
         partition: usize,
     ) -> datafusion_common::Result<Self> {
-        let buffer = PartSortBuffer::All(Vec::new());
+        let buffer = if limit.is_some() {
+            PartSortBuffer::TopK(Vec::new())
+        } else {
+            PartSortBuffer::All(Vec::new())
+        };
 
         // Compute range groups by primary end
         let descending = sort.expression.options.descending;
@@ -373,6 +385,7 @@ impl PartSortStream {
             evaluating_batch: None,
             metrics: BaselineMetrics::new(&sort.metrics, partition),
             dynamic_filter: sort.dynamic_filter.clone(),
+            dynamic_filter_threshold: None,
             range_groups,
             cur_group_idx: 0,
         })
@@ -507,11 +520,54 @@ impl PartSortStream {
         Ok(None)
     }
 
-    fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
+    fn push_buffer(
+        &mut self,
+        batch: DfRecordBatch,
+        sort_data_type: &DataType,
+    ) -> datafusion_common::Result<()> {
+        let topk = matches!(self.buffer, PartSortBuffer::TopK(_));
         match &mut self.buffer {
             PartSortBuffer::All(v) => v.push(batch),
+            PartSortBuffer::TopK(v) => v.push(batch),
         }
 
+        if topk {
+            self.compact_topk_buffer()?;
+            self.update_dynamic_filter(sort_data_type)?;
+        }
+
         Ok(())
     }
 
+    fn compact_topk_buffer(&mut self) -> datafusion_common::Result<()> {
+        let Some(limit) = self.limit else {
+            return Ok(());
+        };
+
+        let PartSortBuffer::TopK(buffer) =
+            std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
+        else {
+            return Ok(());
+        };
+
+        if limit == 0 || buffer.is_empty() {
+            self.buffer = PartSortBuffer::TopK(Vec::new());
+            return Ok(());
+        }
+
+        let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum();
+        if total_rows <= limit {
+            self.buffer = PartSortBuffer::TopK(buffer);
+            return Ok(());
+        }
+
+        let topk = self.sort_record_batches(&buffer, Some(limit), false)?;
+        self.buffer = if topk.num_rows() == 0 {
+            PartSortBuffer::TopK(Vec::new())
+        } else {
+            PartSortBuffer::TopK(vec![topk])
+        };
+
+        Ok(())
+    }
+
@@ -527,7 +583,9 @@ impl PartSortStream {
             return Ok(None);
         }
 
-        let PartSortBuffer::All(buffer) = &self.buffer;
+        let buffer = match &self.buffer {
+            PartSortBuffer::All(buffer) | PartSortBuffer::TopK(buffer) => buffer,
+        };
         let mut sort_columns = Vec::with_capacity(buffer.len());
         let mut opt = None;
         for batch in buffer {
@@ -567,6 +625,90 @@ impl PartSortStream {
         Ok(Some(threshold))
     }
 
+    fn threshold_scalar_value(
+        sort_data_type: &DataType,
+        threshold: &TopKThreshold,
+    ) -> datafusion_common::Result<ScalarValue> {
+        let value = match threshold {
+            TopKThreshold::Null => None,
+            TopKThreshold::Value(value) => Some(*value),
+        };
+
+        let scalar = match sort_data_type {
+            DataType::Timestamp(TimeUnit::Second, tz) => {
+                ScalarValue::TimestampSecond(value, tz.clone())
+            }
+            DataType::Timestamp(TimeUnit::Millisecond, tz) => {
+                ScalarValue::TimestampMillisecond(value, tz.clone())
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, tz) => {
+                ScalarValue::TimestampMicrosecond(value, tz.clone())
+            }
+            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
+                ScalarValue::TimestampNanosecond(value, tz.clone())
+            }
+            _ => internal_err!(
+                "Unsupported data type for sort column: {:?}",
+                sort_data_type
+            )?,
+        };
+
+        Ok(scalar)
+    }
+
+    fn build_dynamic_filter_expr(
+        &self,
+        sort_data_type: &DataType,
+        threshold: &TopKThreshold,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        let op = if self.expression.options.descending {
+            Operator::Gt
+        } else {
+            Operator::Lt
+        };
+        let value_null = matches!(threshold, TopKThreshold::Null);
+        let value = Self::threshold_scalar_value(sort_data_type, threshold)?;
+        let comparison: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            self.expression.expr.clone(),
+            op,
+            lit(value),
+        ));
+
+        match (self.expression.options.nulls_first, value_null) {
+            (true, true) => Ok(lit(false)),
+            (true, false) => Ok(Arc::new(BinaryExpr::new(
+                is_null(self.expression.expr.clone())?,
+                Operator::Or,
+                comparison,
+            ))),
+            (false, true) => is_not_null(self.expression.expr.clone()),
+            (false, false) => Ok(comparison),
+        }
+    }
+
+    fn update_dynamic_filter(
+        &mut self,
+        sort_data_type: &DataType,
+    ) -> datafusion_common::Result<()> {
+        let Some(filter) = &self.dynamic_filter else {
+            return Ok(());
+        };
+
+        let Some(threshold) = self.topk_threshold(sort_data_type)? else {
+            return Ok(());
+        };
+
+        if self.dynamic_filter_threshold.as_ref() == Some(&threshold) {
+            return Ok(());
+        }
+
+        let predicate = self.build_dynamic_filter_expr(sort_data_type, &threshold)?;
+        filter.update(predicate)?;
+        self.dynamic_filter_threshold = Some(threshold);
+
+        Ok(())
+    }
+
     /// Returns true when all rows in the next group are guaranteed to be worse
     /// than the current top-k threshold.
     fn can_stop_before_group(
@@ -661,17 +803,20 @@ impl PartSortStream {
     fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
         match &mut self.buffer {
             PartSortBuffer::All(_) => self.sort_all_buffer(),
+            PartSortBuffer::TopK(_) => self.sort_topk_buffer(),
         }
     }
 
-    /// Internal method for sorting `All` buffer (without limit).
-    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
-        let PartSortBuffer::All(buffer) =
-            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()));
-
+    fn sort_record_batches(
+        &mut self,
+        buffer: &[DfRecordBatch],
+        limit: Option<usize>,
+        check_range: bool,
+    ) -> datafusion_common::Result<DfRecordBatch> {
         if buffer.is_empty() {
             return Ok(DfRecordBatch::new_empty(self.schema.clone()));
         }
+
         let mut sort_columns = Vec::with_capacity(buffer.len());
         let mut opt = None;
         for batch in buffer.iter() {
@@ -688,7 +833,7 @@ impl PartSortStream {
                 )
             })?;
 
-        let indices = sort_to_indices(&sort_column, opt, self.limit).map_err(|e| {
+        let indices = sort_to_indices(&sort_column, opt, limit).map_err(|e| {
             DataFusionError::ArrowError(
                 Box::new(e),
                 Some(format!("Fail to sort to indices at {}", location!())),
@@ -698,7 +843,7 @@ impl PartSortStream {
             return Ok(DfRecordBatch::new_empty(self.schema.clone()));
         }
 
-        if self.limit.is_none() {
+        if check_range {
             self.check_in_range(
                 &sort_column,
                 (
@@ -722,7 +867,7 @@ impl PartSortStream {
         let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
         self.reservation.try_grow(total_mem * 2)?;
 
-        let full_input = concat_batches(&self.schema, &buffer).map_err(|e| {
+        let full_input = concat_batches(&self.schema, buffer).map_err(|e| {
             DataFusionError::ArrowError(
                 Box::new(e),
                 Some(format!(
@@ -748,6 +893,27 @@ impl PartSortStream {
         Ok(sorted)
     }
 
+    /// Internal method for sorting `All` buffer (without limit).
+    fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
+        let PartSortBuffer::All(buffer) =
+            std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
+        else {
+            unreachable!()
+        };
+
+        self.sort_record_batches(&buffer, self.limit, self.limit.is_none())
+    }
+
+    fn sort_topk_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
+        let PartSortBuffer::TopK(buffer) =
+            std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new()))
+        else {
+            unreachable!()
+        };
+
+        self.sort_record_batches(&buffer, self.limit, false)
+    }
+
     /// Sorts current buffer and returns `None` when there is nothing to emit.
     fn sorted_buffer_if_non_empty(&mut self) -> datafusion_common::Result<Option<DfRecordBatch>> {
         if self.buffer.is_empty() {
@@ -814,7 +980,7 @@ impl PartSortStream {
         let next_range_idx = self.try_find_next_range(&sort_column)?;
 
         let Some(idx) = next_range_idx else {
-            self.push_buffer(batch)?;
+            self.push_buffer(batch, sort_column.data_type())?;
             // keep polling input for next batch
             return Ok(());
         };
@@ -822,7 +988,7 @@ impl PartSortStream {
         let this_range = batch.slice(0, idx);
         let remaining_range = batch.slice(idx, batch.num_rows() - idx);
         if this_range.num_rows() != 0 {
-            self.push_buffer(this_range)?;
+            self.push_buffer(this_range, sort_column.data_type())?;
         }
 
         // Step to next proper PartitionRange
@@ -855,7 +1021,7 @@ impl PartSortStream {
         } else if remaining_range.num_rows() != 0 {
             // remaining batch is within the current partition range
            // push to the buffer and continue polling
-            self.push_buffer(remaining_range)?;
+            self.push_buffer(remaining_range, sort_column.data_type())?;
         }
 
         Ok(())
@@ -877,7 +1043,7 @@ impl PartSortStream {
         let next_range_idx = self.try_find_next_range(&sort_column)?;
 
         let Some(idx) = next_range_idx else {
-            self.push_buffer(batch)?;
+            self.push_buffer(batch, sort_column.data_type())?;
             // keep polling input for next batch
            return Ok(None);
         };
@@ -885,7 +1051,7 @@ impl PartSortStream {
         let this_range = batch.slice(0, idx);
         let remaining_range = batch.slice(idx, batch.num_rows() - idx);
         if this_range.num_rows() != 0 {
-            self.push_buffer(this_range)?;
+            self.push_buffer(this_range, sort_column.data_type())?;
         }
 
         // Step to next proper PartitionRange
@@ -910,7 +1076,7 @@ impl PartSortStream {
         } else {
             // remaining batch is within the current partition range
             if remaining_range.num_rows() != 0 {
-                self.push_buffer(remaining_range)?;
+                self.push_buffer(remaining_range, sort_column.data_type())?;
             }
         }
         // Return None to continue collecting within the same group
@@ -930,7 +1096,7 @@ impl PartSortStream {
             // remaining batch is within the current partition range
             // push to the buffer and continue polling
             if remaining_range.num_rows() != 0 {
-                self.push_buffer(remaining_range)?;
+                self.push_buffer(remaining_range, sort_column.data_type())?;
             }
         }
 
@@ -1015,8 +1181,8 @@ mod test {
     use std::sync::Arc;
 
     use arrow::array::{
-        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
-        TimestampSecondArray,
+        BooleanArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+        TimestampNanosecondArray, TimestampSecondArray,
     };
     use arrow::json::ArrayWriter;
     use arrow_schema::{DataType, Field, Schema, SortOptions, TimeUnit};
@@ -1637,6 +1803,77 @@ mod test {
         .await;
     }
 
+    #[test]
+    fn test_topk_buffer_is_bounded_and_updates_dynamic_filter() {
+        let unit = TimeUnit::Millisecond;
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "ts",
+            DataType::Timestamp(unit, None),
+            false,
+        )]));
+        let sort_data_type = DataType::Timestamp(unit, None);
+        let partition_range = PartitionRange {
+            start: Timestamp::new(0, unit.into()),
+            end: Timestamp::new(10, unit.into()),
+            num_rows: 9,
+            identifier: 0,
+        };
+        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
+        let exec = PartSortExec::try_new(
+            PhysicalSortExpr {
+                expr: Arc::new(Column::new("ts", 0)),
+                options: SortOptions {
+                    descending: true,
+                    ..Default::default()
+                },
+            },
+            Some(3),
+            vec![vec![partition_range]],
+            mock_input.clone(),
+        )
+        .unwrap();
+        let input_stream = mock_input
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+        let mut stream = PartSortStream::new(
+            Arc::new(TaskContext::default()),
+            &exec,
+            Some(3),
+            input_stream,
+            vec![partition_range],
+            0,
+        )
+        .unwrap();
+
+        for batch in [vec![1, 2, 3], vec![4, 5, 6], vec![0, 7, 8]] {
+            stream
+                .push_buffer(
+                    DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, batch)])
+                        .unwrap(),
+                    &sort_data_type,
+                )
+                .unwrap();
+            assert_eq!(stream.buffer.num_rows(), 3);
+        }
+
+        let dynamic_filter = stream.dynamic_filter.as_ref().unwrap().clone();
+        let probe = DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![5, 6, 7])])
+            .unwrap();
+        let predicate = dynamic_filter.current().unwrap();
+        let result = predicate
+            .evaluate(&probe)
+            .unwrap()
+            .into_array(probe.num_rows())
+            .unwrap();
+        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+        assert_eq!(result, &BooleanArray::from(vec![false, false, true]));
+
+        let expected =
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![8, 7, 6])])
+                .unwrap();
+        assert_eq!(stream.sort_buffer().unwrap(), expected);
+    }
+
     /// Test that verifies early termination behavior.
     /// Once we've produced limit * num_partitions rows, we should stop
     /// pulling from input stream.
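For intuition, the following is a minimal, dependency-free sketch of the four-case predicate that `build_dynamic_filter_expr` encodes for the `nulls_first` × null-threshold matrix. The names `Threshold` and `row_may_qualify` are hypothetical and not part of this patch; the patch itself builds an `Arc<dyn PhysicalExpr>` out of `BinaryExpr`, `is_null`, `is_not_null`, and `lit` rather than evaluating rows directly.

```rust
// Hypothetical sketch of the predicate logic in `build_dynamic_filter_expr`.
// `Threshold` mirrors `TopKThreshold`; the real code emits a physical expression.
#[derive(Clone, Copy)]
enum Threshold {
    Null,       // the current k-th entry of the top-k is NULL
    Value(i64), // the current k-th sort value
}

/// True when a row could still displace the current k-th entry.
fn row_may_qualify(row: Option<i64>, threshold: Threshold, descending: bool, nulls_first: bool) -> bool {
    let beats = |v: i64, t: i64| if descending { v > t } else { v < t };
    match (nulls_first, threshold) {
        // NULLs sort first and the k-th entry is already NULL: nothing can
        // improve on it, hence the constant `lit(false)` in the patch.
        (true, Threshold::Null) => false,
        // NULLs sort first: a NULL row always qualifies, otherwise compare,
        // matching `is_null(col) OR col <op> threshold`.
        (true, Threshold::Value(t)) => row.map_or(true, |v| beats(v, t)),
        // NULLs sort last and the k-th entry is NULL: any non-NULL value
        // qualifies, matching `is_not_null(col)`.
        (false, Threshold::Null) => row.is_some(),
        // NULLs sort last: a NULL row can never beat a non-NULL threshold;
        // the bare comparison yields NULL (falsy) for NULL inputs, matching this.
        (false, Threshold::Value(t)) => row.map_or(false, |v| beats(v, t)),
    }
}

fn main() {
    // Mirrors the unit test: ORDER BY ts DESC LIMIT 3 with top-3 = [8, 7, 6],
    // so the threshold is 6 and only 7 out of [5, 6, 7] still qualifies.
    let t = Threshold::Value(6);
    assert!(!row_may_qualify(Some(5), t, true, false));
    assert!(!row_may_qualify(Some(6), t, true, false));
    assert!(row_may_qualify(Some(7), t, true, false));
}
```

The point of the case split is safety: a predicate pushed down through a dynamic filter must never drop a row that could still enter the top-k, so NULL ordering has to be encoded in the filter itself rather than left to the bare comparison.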