diff --git a/src/query/src/part_sort.rs b/src/query/src/part_sort.rs index d44304b415..f1574ec334 100644 --- a/src/query/src/part_sort.rs +++ b/src/query/src/part_sort.rs @@ -534,43 +534,69 @@ impl PartSortStream { } if topk { - self.compact_topk_buffer()?; - self.update_dynamic_filter(sort_data_type)?; + let threshold = self.compact_topk_buffer(sort_data_type)?; + self.update_dynamic_filter(sort_data_type, threshold)?; } Ok(()) } - fn compact_topk_buffer(&mut self) -> datafusion_common::Result<()> { + fn compact_topk_buffer( + &mut self, + sort_data_type: &DataType, + ) -> datafusion_common::Result> { let Some(limit) = self.limit else { - return Ok(()); + return Ok(None); }; let PartSortBuffer::TopK(buffer) = std::mem::replace(&mut self.buffer, PartSortBuffer::TopK(Vec::new())) else { - return Ok(()); + return Ok(None); }; if limit == 0 || buffer.is_empty() { self.buffer = PartSortBuffer::TopK(Vec::new()); - return Ok(()); + return Ok(None); } let total_rows: usize = buffer.iter().map(|batch| batch.num_rows()).sum(); if total_rows <= limit { self.buffer = PartSortBuffer::TopK(buffer); - return Ok(()); + return Ok(None); } let topk = self.sort_record_batches(&buffer, Some(limit), false)?; + let threshold = self.threshold_from_sorted_batch(&topk, sort_data_type)?; self.buffer = if topk.num_rows() == 0 { PartSortBuffer::TopK(Vec::new()) } else { PartSortBuffer::TopK(vec![topk]) }; - Ok(()) + Ok(threshold) + } + + fn threshold_from_sorted_batch( + &self, + batch: &DfRecordBatch, + sort_data_type: &DataType, + ) -> datafusion_common::Result> { + if batch.num_rows() == 0 { + return Ok(None); + } + + let threshold_idx = batch.num_rows() - 1; + let sort_column = self.expression.evaluate_to_sort_column(batch)?.values; + let threshold = downcast_ts_array!( + sort_data_type => (threshold_helper, sort_column, threshold_idx), + _ => internal_err!( + "Unsupported data type for sort column: {:?}", + sort_data_type + )?, + ); + + Ok(Some(threshold)) } fn topk_threshold( @@ -581,7 +607,7 @@ impl PartSortStream { return Ok(None); }; - if self.buffer.num_rows() < limit { + if limit == 0 || self.buffer.num_rows() < limit { return Ok(None); } @@ -691,13 +717,19 @@ impl PartSortStream { fn update_dynamic_filter( &mut self, sort_data_type: &DataType, + threshold: Option, ) -> datafusion_common::Result<()> { let Some(filter) = &self.dynamic_filter else { return Ok(()); }; - let Some(threshold) = self.topk_threshold(sort_data_type)? else { - return Ok(()); + let threshold = if let Some(threshold) = threshold { + threshold + } else { + let Some(threshold) = self.topk_threshold(sort_data_type)? else { + return Ok(()); + }; + threshold }; if self.dynamic_filter_threshold.as_ref() == Some(&threshold) { @@ -722,8 +754,13 @@ impl PartSortStream { return Ok(false); } - let Some(threshold) = self.topk_threshold(sort_data_type)? else { - return Ok(false); + let threshold = if let Some(threshold) = &self.dynamic_filter_threshold { + threshold.clone() + } else { + let Some(threshold) = self.topk_threshold(sort_data_type)? else { + return Ok(false); + }; + threshold }; let (_, start_idx, _) = self.range_groups[group_idx]; @@ -1881,6 +1918,59 @@ mod test { assert_eq!(stream.sort_buffer().unwrap(), expected); } + #[test] + fn test_topk_limit_zero_clears_buffer_without_threshold() { + let unit = TimeUnit::Millisecond; + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(unit, None), + false, + )])); + let sort_data_type = DataType::Timestamp(unit, None); + let partition_range = PartitionRange { + start: Timestamp::new(0, unit.into()), + end: Timestamp::new(10, unit.into()), + num_rows: 3, + identifier: 0, + }; + let mock_input = Arc::new(MockInputExec::new(vec![vec![]], schema.clone())); + let exec = PartSortExec::try_new( + PhysicalSortExpr { + expr: Arc::new(Column::new("ts", 0)), + options: SortOptions { + descending: true, + ..Default::default() + }, + }, + Some(0), + vec![vec![partition_range]], + mock_input.clone(), + ) + .unwrap(); + let input_stream = mock_input + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + let mut stream = PartSortStream::new( + Arc::new(TaskContext::default()), + &exec, + Some(0), + input_stream, + vec![partition_range], + 0, + ) + .unwrap(); + + stream + .push_buffer( + DfRecordBatch::try_new(schema, vec![new_ts_array(unit, vec![1, 2, 3])]).unwrap(), + &sort_data_type, + ) + .unwrap(); + + assert_eq!(stream.buffer.num_rows(), 0); + assert_eq!(stream.dynamic_filter_threshold, None); + } + /// Test that verifies early termination behavior. /// Once we've produced limit * num_partitions rows, we should stop /// pulling from input stream.