From fdedbb8261a253634d6eac7c1f6bfbbc55880d85 Mon Sep 17 00:00:00 2001
From: discord9 <55937128+discord9@users.noreply.github.com>
Date: Tue, 23 Dec 2025 17:24:55 +0800
Subject: [PATCH] fix: part sort share same topk dyn filter&early stop use dyn filter (#7460)

* fix: part sort share same topk dyn filter

Signed-off-by: discord9

* test: one

Signed-off-by: discord9

* feat: use dyn filter properly instead

Signed-off-by: discord9

* c

Signed-off-by: discord9

* docs: explain why dyn filter work

Signed-off-by: discord9

* chore: after rebase fix

Signed-off-by: discord9

---------

Signed-off-by: discord9
---
 src/query/src/part_sort.rs | 235 +++++++++++++++++++++++++++----------
 1 file changed, 174 insertions(+), 61 deletions(-)

diff --git a/src/query/src/part_sort.rs b/src/query/src/part_sort.rs
index 7a6fa18836..22682b9a3a 100644
--- a/src/query/src/part_sort.rs
+++ b/src/query/src/part_sort.rs
@@ -23,11 +23,16 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use arrow::array::ArrayRef;
+use arrow::array::{
+    ArrayRef, AsArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+    TimestampNanosecondArray, TimestampSecondArray,
+};
 use arrow::compute::{concat, concat_batches, take_record_batch};
-use arrow_schema::SchemaRef;
+use arrow_schema::{Schema, SchemaRef};
 use common_recordbatch::{DfRecordBatch, DfSendableRecordBatchStream};
+use common_telemetry::warn;
 use common_time::Timestamp;
+use common_time::timestamp::TimeUnit;
 use datafusion::common::arrow::compute::sort_to_indices;
 use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion::execution::{RecordBatchStream, TaskContext};
@@ -40,8 +45,9 @@ use datafusion::physical_plan::{
     DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
     TopKDynamicFilters,
 };
+use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{DataFusionError, internal_err};
-use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit};
+use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 use futures::{Stream, StreamExt};
 use itertools::Itertools;
@@ -347,6 +353,9 @@ struct PartSortStream {
     range_groups: Vec<(Timestamp, usize, usize)>,
     /// Current group being processed (index into range_groups).
     cur_group_idx: usize,
+    /// Dynamic filter shared by every TopK instance; `PartSortExec`/`PartSortStream`/`TopK` must share the same filter
+    /// so that updates from each `TopK` are visible to the others (and to the table scan operator).
+    filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
 }
 
 impl PartSortStream {
@@ -360,7 +369,7 @@ impl PartSortStream {
         filter: Option<Arc<RwLock<TopKDynamicFilters>>>,
     ) -> datafusion_common::Result<Self> {
         let buffer = if let Some(limit) = limit {
-            let Some(filter) = filter else {
+            let Some(filter) = filter.clone() else {
                 return internal_err!(
                     "TopKDynamicFilters must be provided when limit is set at {}",
                     snafu::location!()
@@ -377,7 +386,7 @@ impl PartSortStream {
                     context.session_config().batch_size(),
                     context.runtime_env(),
                     &sort.metrics,
-                    filter,
+                    filter.clone(),
                 )?,
                 0,
             )
@@ -407,23 +416,11 @@ impl PartSortStream {
             root_metrics: sort.metrics.clone(),
             range_groups,
             cur_group_idx: 0,
+            filter,
        })
    }
 }
 
-macro_rules! ts_to_timestamp {
-    ($t:ty, $unit:expr, $arr:expr) => {{
-        let arr = $arr
-            .as_any()
-            .downcast_ref::<PrimitiveArray<$t>>()
-            .unwrap();
-
-        arr.iter()
-            .map(|v| v.map(|v| Timestamp::new(v, common_time::timestamp::TimeUnit::from(&$unit))))
-            .collect_vec()
-    }};
-}
-
 macro_rules! array_check_helper {
     ($t:ty, $unit:expr, $arr:expr, $cur_range:expr, $min_max_idx:expr) => {{
         if $cur_range.start.unit().as_arrow_time_unit() != $unit
@@ -546,9 +543,10 @@ impl PartSortStream {
         Ok(())
     }
 
-    /// A temporary solution for stop read earlier when current group do not overlap with any of those next group
+    /// Stop reading early when the current group does not overlap with any of the following groups.
     /// If not overlap, we can stop read further input as current top k is final
-    fn can_stop_early(&mut self) -> datafusion_common::Result<bool> {
+    /// Use the dynamic filter to evaluate the next group's primary end.
+    fn can_stop_early(&mut self, schema: &Arc<Schema>) -> datafusion_common::Result<bool> {
         let topk_cnt = match &self.buffer {
             PartSortBuffer::Top(_, cnt) => *cnt,
             _ => return Ok(false),
@@ -557,46 +555,74 @@ impl PartSortStream {
         if Some(topk_cnt) < self.limit {
             return Ok(false);
         }
-        // else check if last value in topk is not in next group range
-        let topk_buffer = self.sort_top_buffer()?;
-
-        // Guard against empty buffer - this can happen if TopK's internal filtering
-        // removed all rows, or if the buffer was cleared. In this case, we cannot
-        // determine if we can stop early, so continue processing.
-        // Fixes: https://github.com/orgs/GreptimeTeam/discussions/7457
-        if topk_buffer.num_rows() == 0 {
-            return Ok(false);
-        }
-
-        let min_batch = topk_buffer.slice(topk_buffer.num_rows() - 1, 1);
-        let min_sort_column = self.expression.evaluate_to_sort_column(&min_batch)?.values;
-        let last_val = downcast_ts_array!(
-            min_sort_column.data_type() => (ts_to_timestamp, min_sort_column),
-            _ => internal_err!(
-                "Unsupported data type for sort column: {:?}",
-                min_sort_column.data_type()
-            )?,
-        )[0];
-        let Some(last_val) = last_val else {
-            return Ok(false);
-        };
         let next_group_primary_end = if self.cur_group_idx + 1 < self.range_groups.len() {
             self.range_groups[self.cur_group_idx + 1].0
         } else {
             // no next group
             return Ok(false);
        };
-        let descending = self.expression.options.descending;
-        let not_in_next_group_range = if descending {
-            last_val >= next_group_primary_end
-        } else {
-            last_val < next_group_primary_end
+
+        // The dyn filter is updated from the last value ("threshold") of the TopK heap;
+        // it's a max-heap for an ASC TopK operator,
+        // so the dyn filter can be used to prune the data range.
+        let filter = self
+            .filter
+            .as_ref()
+            .expect("TopKDynamicFilters must be provided when limit is set");
+        let filter = filter.read().expr().current()?;
+        let mut ts_index = None;
+        // invariant: the filter must reference only the time index column
+        let filter = filter
+            .transform_down(|c| {
+                // rewrite every column's index to 0
+                if let Some(column) = c.as_any().downcast_ref::<Column>() {
+                    ts_index = Some(column.index());
+                    Ok(Transformed::yes(
+                        Arc::new(Column::new(column.name(), 0)) as Arc<dyn PhysicalExpr>
+                    ))
+                } else {
+                    Ok(Transformed::no(c))
+                }
+            })?
+            .data;
+        let Some(ts_index) = ts_index else {
+            return Ok(false); // dyn filter is still true, cannot decide, continue reading
         };
-
-        // refill topk buffer count
-        self.push_buffer(topk_buffer)?;
-
-        Ok(not_in_next_group_range)
+        let field = if schema.fields().len() <= ts_index {
+            warn!(
+                "Schema mismatch when evaluating dynamic filter for PartSortExec at {}, schema: {:?}, ts_index: {}",
+                self.partition, schema, ts_index
+            );
+            return Ok(false); // schema mismatch, cannot decide, continue reading
+        } else {
+            schema.field(ts_index)
+        };
+        let schema = Arc::new(Schema::new(vec![field.clone()]));
+        // Convert next_group_primary_end to an array and evaluate the filter on it; if it evaluates to false, there is no overlap and we can stop early.
+        let primary_end_array = match next_group_primary_end.unit() {
+            TimeUnit::Second => Arc::new(TimestampSecondArray::from(vec![
+                next_group_primary_end.value(),
+            ])) as ArrayRef,
+            TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(vec![
+                next_group_primary_end.value(),
+            ])) as ArrayRef,
+            TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(vec![
+                next_group_primary_end.value(),
+            ])) as ArrayRef,
+            TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(vec![
+                next_group_primary_end.value(),
+            ])) as ArrayRef,
+        };
+        let primary_end_batch = DfRecordBatch::try_new(schema, vec![primary_end_array])?;
+        let res = filter.evaluate(&primary_end_batch)?;
+        let array = res.into_array(primary_end_batch.num_rows())?;
+        let filter = array.as_boolean().clone();
+        let overlap = filter.iter().next().flatten();
+        if let Some(false) = overlap {
+            Ok(true)
+        } else {
+            Ok(false)
+        }
     }
 
     /// Check if the given partition index is within the current group.
@@ -749,9 +775,13 @@ impl PartSortStream {
 
     /// Internal method for sorting `Top` buffer (with limit).
     fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
-        let filter = Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
-            DynamicFilterPhysicalExpr::new(vec![], lit(true)),
-        ))));
+        let Some(filter) = self.filter.clone() else {
+            return internal_err!(
+                "TopKDynamicFilters must be provided when sorting with limit at {}",
+                snafu::location!()
+            );
+        };
+
         let new_top_buffer = TopK::try_new(
             self.partition,
             self.schema().clone(),
@@ -888,7 +918,7 @@ impl PartSortStream {
 
         // When TopK is fulfilled and we are switching to a new group, stop consuming further ranges if possible.
         // read from topk heap and determine whether we can stop earlier.
-        if !in_same_group && self.can_stop_early()? {
+        if !in_same_group && self.can_stop_early(&batch.schema())? {
             self.input_complete = true;
             self.evaluating_batch = None;
             return Ok(());
@@ -1127,7 +1157,7 @@ mod test {
 
         // The TopK result buffer is empty, so we cannot determine early-stop.
         // Ensure this path returns `Ok(false)` (and, importantly, does not panic).
-        assert!(!stream.can_stop_early().unwrap());
+        assert!(!stream.can_stop_early(&schema).unwrap());
     }
 
     #[ignore = "hard to gen expected data correctly here, TODO(discord9): fix it later"]
@@ -2096,12 +2126,11 @@
 
         // Group 1 (end=100) has 6 rows, TopK will keep top 4
-        // Group 2 (end=98) has 3 rows - threshold (96) < 98, so next group
-        // could theoretically have better values. But limit exhaustion stops us.
-        // Note: Data values must not overlap between ranges to avoid ambiguity.
+        // Group 2 (end=98) has 3 rows - threshold (96) < 98, so next group
+        // could theoretically have better values. Continue reading.
         let input_ranged_data = vec![
             (
                 PartitionRange {
-                    start: Timestamp::new(70, unit.into()),
+                    start: Timestamp::new(90, unit.into()),
                     end: Timestamp::new(100, unit.into()),
                     num_rows: 6,
                     identifier: 0,
@@ -2888,4 +2917,88 @@
         )
         .await;
     }
+
+    /// First group: [0,20), data: [0, 5, 15]
+    /// Second group: [10, 30), data: [21, 25, 29]
+    /// After the first group, call the early-stop check manually and check whether the filter is updated.
+    #[tokio::test]
+    async fn test_early_stop_check_update_dyn_filter() {
+        let unit = TimeUnit::Millisecond;
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "ts",
+            DataType::Timestamp(unit, None),
+            false,
+        )]));
+
+        let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone()));
+        let exec = PartSortExec::try_new(
+            PhysicalSortExpr {
+                expr: Arc::new(Column::new("ts", 0)),
+                options: SortOptions {
+                    descending: false,
+                    ..Default::default()
+                },
+            },
+            Some(3),
+            vec![vec![
+                PartitionRange {
+                    start: Timestamp::new(0, unit.into()),
+                    end: Timestamp::new(20, unit.into()),
+                    num_rows: 3,
+                    identifier: 1,
+                },
+                PartitionRange {
+                    start: Timestamp::new(10, unit.into()),
+                    end: Timestamp::new(30, unit.into()),
+                    num_rows: 3,
+                    identifier: 1,
+                },
+            ]],
+            mock_input.clone(),
+        )
+        .unwrap();
+
+        let filter = exec.filter.clone().unwrap();
+        let input_stream = mock_input
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+        let mut stream = PartSortStream::new(
+            Arc::new(TaskContext::default()),
+            &exec,
+            Some(3),
+            input_stream,
+            vec![],
+            0,
+            Some(filter.clone()),
+        )
+        .unwrap();
+
+        // initially, snapshot_generation is 1
+        assert_eq!(filter.read().expr().snapshot_generation(), 1);
+        let batch =
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![0, 5, 15])])
+                .unwrap();
+        stream.push_buffer(batch).unwrap();
+
+        // after pushing the first batch, snapshot_generation is updated to 2
+        assert_eq!(filter.read().expr().snapshot_generation(), 2);
+        assert!(!stream.can_stop_early(&schema).unwrap());
+        // still 2, as the early-stop check does not update the filter
+        assert_eq!(filter.read().expr().snapshot_generation(), 2);
+
+        let _ = stream.sort_top_buffer().unwrap();
+
+        let batch =
+            DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit, vec![21, 25, 29])])
+                .unwrap();
+        stream.push_buffer(batch).unwrap();
+        // still 2, as the filter was not updated
+        assert_eq!(filter.read().expr().snapshot_generation(), 2);
+        let new = stream.sort_top_buffer().unwrap();
+        // still 2, as the filter was not updated
+        assert_eq!(filter.read().expr().snapshot_generation(), 2);
+
+        // the dyn filter kicks in and filters out all rows >= 15 (the filter is rows < 15)
+        assert_eq!(new.num_rows(), 0)
+    }
 }
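
The core idea of the fix, in isolation: every `TopK` instance, the `PartSortStream` early-stop check, and the table scan must observe the same filter state, so a threshold published by one `TopK` is immediately visible to the others. The sketch below is a minimal, self-contained illustration of that invariant using plain std types and hypothetical names (`SharedThreshold`, `update`, `accepts`); it is not the real `TopKDynamicFilters`/`DynamicFilterPhysicalExpr` API, only a simplified model of how the shared predicate is produced and consumed.

use std::sync::{Arc, RwLock};

/// Stand-in for the shared dynamic filter: an upper bound on timestamps that
/// can still enter an ascending TopK heap (None = filter is still `true`).
#[derive(Clone, Default)]
struct SharedThreshold(Arc<RwLock<Option<i64>>>);

impl SharedThreshold {
    /// Called by a TopK consumer once its heap is full: publish/tighten the bound.
    fn update(&self, heap_max: i64) {
        let mut guard = self.0.write().unwrap();
        *guard = Some(match *guard {
            Some(current) => current.min(heap_max),
            None => heap_max,
        });
    }

    /// Called by the scan or the early-stop check: does `ts` still pass the filter?
    fn accepts(&self, ts: i64) -> bool {
        match *self.0.read().unwrap() {
            Some(threshold) => ts < threshold,
            None => true,
        }
    }
}

fn main() {
    // One shared state, cloned into every cooperating component; this mirrors the
    // single Arc<RwLock<TopKDynamicFilters>> now handed to PartSortExec,
    // PartSortStream and each TopK, instead of each TopK building its own filter.
    let filter = SharedThreshold::default();
    let topk_view = filter.clone();
    let check_view = filter.clone();

    // The TopK of the first group keeps [0, 5, 15] with limit 3, so its heap max
    // (the "threshold") is 15 and the published filter is effectively ts < 15.
    topk_view.update(15);

    // Early-stop check: evaluate the shared filter on the next group's boundary
    // timestamp. If even that boundary fails the filter, no row of the next group
    // can improve the current top-k, so reading further input can stop.
    let next_group_boundary = 21;
    let can_stop_early = !check_view.accepts(next_group_boundary);
    assert!(can_stop_early);
    println!("can stop early: {can_stop_early}");
}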