diff --git a/src/query/src/lib.rs b/src/query/src/lib.rs index 18b0b22582..435e9b4bcc 100644 --- a/src/query/src/lib.rs +++ b/src/query/src/lib.rs @@ -16,6 +16,7 @@ #![feature(int_roundings)] #![feature(trait_upcasting)] #![feature(try_blocks)] +#![feature(stmt_expr_attributes)] mod analyze; pub mod dataframe; diff --git a/src/query/src/part_sort.rs b/src/query/src/part_sort.rs index bcf8250205..9339c16b08 100644 --- a/src/query/src/part_sort.rs +++ b/src/query/src/part_sort.rs @@ -33,11 +33,11 @@ use datafusion::execution::{RecordBatchStream, TaskContext}; use datafusion::physical_plan::coalesce_batches::concat_batches; use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK, }; use datafusion_common::{internal_err, DataFusionError}; use datafusion_physical_expr::PhysicalSortExpr; -use futures::Stream; +use futures::{Stream, StreamExt}; use itertools::Itertools; use snafu::location; use store_api::region_engine::PartitionRange; @@ -108,7 +108,7 @@ impl PartSortExec { input_stream, self.partition_ranges[partition].clone(), partition, - )) as _; + )?) as _; Ok(df_stream) } @@ -185,10 +185,28 @@ impl ExecutionPlan for PartSortExec { } } +enum PartSortBuffer { + All(Vec), + /// TopK buffer with row count. + /// + /// Given this heap only keeps k element, the capacity of this buffer + /// is not accurate, and is only used for empty check. + Top(TopK, usize), +} + +impl PartSortBuffer { + pub fn is_empty(&self) -> bool { + match self { + PartSortBuffer::All(v) => v.is_empty(), + PartSortBuffer::Top(_, cnt) => *cnt == 0, + } + } +} + struct PartSortStream { /// Memory pool for this stream reservation: MemoryReservation, - buffer: Vec, + buffer: PartSortBuffer, expression: PhysicalSortExpr, limit: Option, produced: usize, @@ -201,6 +219,8 @@ struct PartSortStream { cur_part_idx: usize, evaluating_batch: Option, metrics: BaselineMetrics, + context: Arc, + root_metrics: ExecutionPlanMetricsSet, } impl PartSortStream { @@ -211,11 +231,29 @@ impl PartSortStream { input: DfSendableRecordBatchStream, partition_ranges: Vec, partition: usize, - ) -> Self { - Self { + ) -> datafusion_common::Result { + let buffer = if let Some(limit) = limit { + PartSortBuffer::Top( + TopK::try_new( + partition, + sort.schema().clone(), + vec![sort.expression.clone()], + limit, + context.session_config().batch_size(), + context.runtime_env(), + &sort.metrics, + partition, + )?, + 0, + ) + } else { + PartSortBuffer::All(Vec::new()) + }; + + Ok(Self { reservation: MemoryConsumer::new("PartSortStream".to_string()) .register(&context.runtime_env().memory_pool), - buffer: Vec::new(), + buffer, expression: sort.expression.clone(), limit, produced: 0, @@ -227,7 +265,9 @@ impl PartSortStream { cur_part_idx: 0, evaluating_batch: None, metrics: BaselineMetrics::new(&sort.metrics, partition), - } + context, + root_metrics: sort.metrics.clone(), + }) } } @@ -338,16 +378,42 @@ impl PartSortStream { Ok(None) } + fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> { + match &mut self.buffer { + PartSortBuffer::All(v) => v.push(batch), + PartSortBuffer::Top(top, cnt) => { + *cnt += batch.num_rows(); + top.insert_batch(batch)?; + } + } + + Ok(()) + } + /// Sort and clear the buffer and return the sorted record batch /// /// this function will return a empty record batch if the buffer is empty fn sort_buffer(&mut self) -> datafusion_common::Result { - if self.buffer.is_empty() { + match &mut self.buffer { + PartSortBuffer::All(_) => self.sort_all_buffer(), + PartSortBuffer::Top(_, _) => self.sort_top_buffer(), + } + } + + /// Internal method for sorting `All` buffer (without limit). + fn sort_all_buffer(&mut self) -> datafusion_common::Result { + let PartSortBuffer::All(buffer) = + std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new())) + else { + unreachable!("buffer type is checked before and should be All variant") + }; + + if buffer.is_empty() { return Ok(DfRecordBatch::new_empty(self.schema.clone())); } - let mut sort_columns = Vec::with_capacity(self.buffer.len()); + let mut sort_columns = Vec::with_capacity(buffer.len()); let mut opt = None; - for batch in self.buffer.iter() { + for batch in buffer.iter() { let sort_column = self.expression.evaluate_to_sort_column(batch)?; opt = opt.or(sort_column.options); sort_columns.push(sort_column.values); @@ -390,13 +456,13 @@ impl PartSortStream { })?; // reserve memory for the concat input and sorted output - let total_mem: usize = self.buffer.iter().map(|r| r.get_array_memory_size()).sum(); + let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum(); self.reservation.try_grow(total_mem * 2)?; let full_input = concat_batches( &self.schema, - &self.buffer, - self.buffer.iter().map(|r| r.num_rows()).sum(), + &buffer, + buffer.iter().map(|r| r.num_rows()).sum(), ) .map_err(|e| { DataFusionError::ArrowError( @@ -418,8 +484,6 @@ impl PartSortStream { ) })?; - // only clear after sorted for better debugging - self.buffer.clear(); self.produced += sorted.num_rows(); drop(full_input); // here remove both buffer and full_input memory @@ -427,6 +491,68 @@ impl PartSortStream { Ok(sorted) } + /// Internal method for sorting `Top` buffer (with limit). + fn sort_top_buffer(&mut self) -> datafusion_common::Result { + let new_top_buffer = TopK::try_new( + self.partition, + self.schema().clone(), + vec![self.expression.clone()], + self.limit.unwrap(), + self.context.session_config().batch_size(), + self.context.runtime_env(), + &self.root_metrics, + self.partition, + )?; + let PartSortBuffer::Top(top_k, _) = + std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0)) + else { + unreachable!("buffer type is checked before and should be Top variant") + }; + + let mut result_stream = top_k.emit()?; + let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref()); + let mut results = vec![]; + let mut row_count = 0; + // according to the current implementation of `TopK`, the result stream will always be ready + loop { + match result_stream.poll_next_unpin(&mut placeholder_ctx) { + Poll::Ready(Some(batch)) => { + let batch = batch?; + row_count += batch.num_rows(); + results.push(batch); + } + Poll::Pending => { + #[cfg(debug_assertions)] + unreachable!("TopK result stream should always be ready") + } + Poll::Ready(None) => { + break; + } + } + } + + let concat_batch = concat_batches(&self.schema, &results, row_count).map_err(|e| { + DataFusionError::ArrowError( + e, + Some(format!( + "Fail to concat top k result record batch when sorting at {}", + location!() + )), + ) + })?; + + Ok(concat_batch) + } + + /// Try to split the input batch if it contains data that exceeds the current partition range. + /// + /// When the input batch contains data that exceeds the current partition range, this function + /// will split the input batch into two parts, the first part is within the current partition + /// range will be merged and sorted with previous buffer, and the second part will be registered + /// to `evaluating_batch` for next polling. + /// + /// Returns `None` if the input batch is empty or fully within the current partition range, and + /// `Some(batch)` otherwise. fn split_batch( &mut self, batch: DfRecordBatch, @@ -443,7 +569,7 @@ impl PartSortStream { let next_range_idx = self.try_find_next_range(&sort_column)?; let Some(idx) = next_range_idx else { - self.buffer.push(batch); + self.push_buffer(batch)?; // keep polling input for next batch return Ok(None); }; @@ -451,7 +577,7 @@ impl PartSortStream { let this_range = batch.slice(0, idx); let remaining_range = batch.slice(idx, batch.num_rows() - idx); if this_range.num_rows() != 0 { - self.buffer.push(this_range); + self.push_buffer(this_range)?; } // mark end of current PartitionRange let sorted_batch = self.sort_buffer(); @@ -466,7 +592,7 @@ impl PartSortStream { // remaining batch is within the current partition range // push to the buffer and continue polling if remaining_range.num_rows() != 0 { - self.buffer.push(remaining_range); + self.push_buffer(remaining_range)?; } } @@ -556,9 +682,12 @@ mod test { #[tokio::test] async fn fuzzy_test() { let test_cnt = 100; + // bound for total count of PartitionRange let part_cnt_bound = 100; + // bound for timestamp range size and offset for each PartitionRange let range_size_bound = 100; let range_offset_bound = 100; + // bound for batch count and size within each PartitionRange let batch_cnt_bound = 20; let batch_size_bound = 100; @@ -575,6 +704,11 @@ mod test { descending, nulls_first, }; + let limit = if rng.bool() { + Some(rng.usize(0..batch_cnt_bound * batch_size_bound)) + } else { + None + }; let unit = match rng.u8(0..3) { 0 => TimeUnit::Second, 1 => TimeUnit::Millisecond, @@ -685,19 +819,39 @@ mod test { DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit.clone(), a)]) .unwrap() }) + .map(|rb| { + // trim expected output with limit + if let Some(limit) = limit + && rb.num_rows() > limit + { + rb.slice(0, limit) + } else { + rb + } + }) .collect_vec(); + test_cases.push(( case_id, unit, input_ranged_data, schema, opt, + limit, expected_output, )); } - for (case_id, _unit, input_ranged_data, schema, opt, expected_output) in test_cases { - run_test(case_id, input_ranged_data, schema, opt, expected_output).await; + for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases { + run_test( + case_id, + input_ranged_data, + schema, + opt, + limit, + expected_output, + ) + .await; } } @@ -711,6 +865,7 @@ mod test { ((5, 10), vec![vec![5, 6], vec![7, 8]]), ], false, + None, vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]], ), ( @@ -720,6 +875,7 @@ mod test { ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]), ], true, + None, vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]], ), ( @@ -729,6 +885,7 @@ mod test { ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]), ], true, + None, vec![vec![8, 7, 6, 5, 4, 3, 2, 1]], ), ( @@ -740,6 +897,7 @@ mod test { ((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]), ], true, + None, vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]], ), ( @@ -751,6 +909,7 @@ mod test { ((0, 10), vec![]), ], true, + None, vec![], ), ( @@ -765,6 +924,7 @@ mod test { ((0, 10), vec![]), ], true, + None, vec![ vec![19, 17, 15], vec![12, 11, 10], @@ -772,9 +932,24 @@ mod test { vec![4, 3, 2, 1], ], ), + ( + TimeUnit::Millisecond, + vec![ + ( + (15, 20), + vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]], + ), + ((10, 15), vec![]), + ((5, 10), vec![]), + ((0, 10), vec![]), + ], + true, + Some(2), + vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]], + ), ]; - for (identifier, (unit, input_ranged_data, descending, expected_output)) in + for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in testcases.into_iter().enumerate() { let schema = Schema::new(vec![Field::new( @@ -787,6 +962,7 @@ mod test { descending, ..Default::default() }; + let input_ranged_data = input_ranged_data .into_iter() .map(|(range, data)| { @@ -821,6 +997,7 @@ mod test { input_ranged_data, schema.clone(), opt, + limit, expected_output, ) .await; @@ -833,8 +1010,19 @@ mod test { input_ranged_data: Vec<(PartitionRange, Vec)>, schema: SchemaRef, opt: SortOptions, + limit: Option, expected_output: Vec, ) { + for rb in &expected_output { + if let Some(limit) = limit { + assert!( + rb.num_rows() <= limit, + "Expect row count in expected output's batch({}) <= limit({})", + rb.num_rows(), + limit + ); + } + } let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip(); let batches = batches @@ -851,7 +1039,7 @@ mod test { expr: Arc::new(Column::new("ts", 0)), options: opt, }, - None, + limit, vec![ranges.clone()], Arc::new(mock_input), );