mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-24 17:00:37 +00:00
feat: reimplement limit in PartSort to reduce memory footprint (#5018)
* feat: support windowed sort with where condition Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix split logic Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * modify fuzz test to reflect logic change Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * feat: handle sort that wont preserving partition Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * clean up Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix typo Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * fix test case and add more cases Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * basic impl Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * install topk Signed-off-by: Ruihang Xia <waynestxia@gmail.com> * tests: add test for limit * add debug assertion Signed-off-by: Ruihang Xia <waynestxia@gmail.com> --------- Signed-off-by: Ruihang Xia <waynestxia@gmail.com> Co-authored-by: discord9 <discord9@163.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#![feature(int_roundings)]
|
||||
#![feature(trait_upcasting)]
|
||||
#![feature(try_blocks)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
|
||||
mod analyze;
|
||||
pub mod dataframe;
|
||||
|
||||
@@ -33,11 +33,11 @@ use datafusion::execution::{RecordBatchStream, TaskContext};
|
||||
use datafusion::physical_plan::coalesce_batches::concat_batches;
|
||||
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
|
||||
use datafusion::physical_plan::{
|
||||
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
|
||||
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
|
||||
};
|
||||
use datafusion_common::{internal_err, DataFusionError};
|
||||
use datafusion_physical_expr::PhysicalSortExpr;
|
||||
use futures::Stream;
|
||||
use futures::{Stream, StreamExt};
|
||||
use itertools::Itertools;
|
||||
use snafu::location;
|
||||
use store_api::region_engine::PartitionRange;
|
||||
@@ -108,7 +108,7 @@ impl PartSortExec {
|
||||
input_stream,
|
||||
self.partition_ranges[partition].clone(),
|
||||
partition,
|
||||
)) as _;
|
||||
)?) as _;
|
||||
|
||||
Ok(df_stream)
|
||||
}
|
||||
@@ -185,10 +185,28 @@ impl ExecutionPlan for PartSortExec {
|
||||
}
|
||||
}
|
||||
|
||||
enum PartSortBuffer {
|
||||
All(Vec<DfRecordBatch>),
|
||||
/// TopK buffer with row count.
|
||||
///
|
||||
/// Given this heap only keeps k element, the capacity of this buffer
|
||||
/// is not accurate, and is only used for empty check.
|
||||
Top(TopK, usize),
|
||||
}
|
||||
|
||||
impl PartSortBuffer {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
PartSortBuffer::All(v) => v.is_empty(),
|
||||
PartSortBuffer::Top(_, cnt) => *cnt == 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PartSortStream {
|
||||
/// Memory pool for this stream
|
||||
reservation: MemoryReservation,
|
||||
buffer: Vec<DfRecordBatch>,
|
||||
buffer: PartSortBuffer,
|
||||
expression: PhysicalSortExpr,
|
||||
limit: Option<usize>,
|
||||
produced: usize,
|
||||
@@ -201,6 +219,8 @@ struct PartSortStream {
|
||||
cur_part_idx: usize,
|
||||
evaluating_batch: Option<DfRecordBatch>,
|
||||
metrics: BaselineMetrics,
|
||||
context: Arc<TaskContext>,
|
||||
root_metrics: ExecutionPlanMetricsSet,
|
||||
}
|
||||
|
||||
impl PartSortStream {
|
||||
@@ -211,11 +231,29 @@ impl PartSortStream {
|
||||
input: DfSendableRecordBatchStream,
|
||||
partition_ranges: Vec<PartitionRange>,
|
||||
partition: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
) -> datafusion_common::Result<Self> {
|
||||
let buffer = if let Some(limit) = limit {
|
||||
PartSortBuffer::Top(
|
||||
TopK::try_new(
|
||||
partition,
|
||||
sort.schema().clone(),
|
||||
vec![sort.expression.clone()],
|
||||
limit,
|
||||
context.session_config().batch_size(),
|
||||
context.runtime_env(),
|
||||
&sort.metrics,
|
||||
partition,
|
||||
)?,
|
||||
0,
|
||||
)
|
||||
} else {
|
||||
PartSortBuffer::All(Vec::new())
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
reservation: MemoryConsumer::new("PartSortStream".to_string())
|
||||
.register(&context.runtime_env().memory_pool),
|
||||
buffer: Vec::new(),
|
||||
buffer,
|
||||
expression: sort.expression.clone(),
|
||||
limit,
|
||||
produced: 0,
|
||||
@@ -227,7 +265,9 @@ impl PartSortStream {
|
||||
cur_part_idx: 0,
|
||||
evaluating_batch: None,
|
||||
metrics: BaselineMetrics::new(&sort.metrics, partition),
|
||||
}
|
||||
context,
|
||||
root_metrics: sort.metrics.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,16 +378,42 @@ impl PartSortStream {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
|
||||
match &mut self.buffer {
|
||||
PartSortBuffer::All(v) => v.push(batch),
|
||||
PartSortBuffer::Top(top, cnt) => {
|
||||
*cnt += batch.num_rows();
|
||||
top.insert_batch(batch)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sort and clear the buffer and return the sorted record batch
|
||||
///
|
||||
/// this function will return a empty record batch if the buffer is empty
|
||||
fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
|
||||
if self.buffer.is_empty() {
|
||||
match &mut self.buffer {
|
||||
PartSortBuffer::All(_) => self.sort_all_buffer(),
|
||||
PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal method for sorting `All` buffer (without limit).
|
||||
fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
|
||||
let PartSortBuffer::All(buffer) =
|
||||
std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
|
||||
else {
|
||||
unreachable!("buffer type is checked before and should be All variant")
|
||||
};
|
||||
|
||||
if buffer.is_empty() {
|
||||
return Ok(DfRecordBatch::new_empty(self.schema.clone()));
|
||||
}
|
||||
let mut sort_columns = Vec::with_capacity(self.buffer.len());
|
||||
let mut sort_columns = Vec::with_capacity(buffer.len());
|
||||
let mut opt = None;
|
||||
for batch in self.buffer.iter() {
|
||||
for batch in buffer.iter() {
|
||||
let sort_column = self.expression.evaluate_to_sort_column(batch)?;
|
||||
opt = opt.or(sort_column.options);
|
||||
sort_columns.push(sort_column.values);
|
||||
@@ -390,13 +456,13 @@ impl PartSortStream {
|
||||
})?;
|
||||
|
||||
// reserve memory for the concat input and sorted output
|
||||
let total_mem: usize = self.buffer.iter().map(|r| r.get_array_memory_size()).sum();
|
||||
let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
|
||||
self.reservation.try_grow(total_mem * 2)?;
|
||||
|
||||
let full_input = concat_batches(
|
||||
&self.schema,
|
||||
&self.buffer,
|
||||
self.buffer.iter().map(|r| r.num_rows()).sum(),
|
||||
&buffer,
|
||||
buffer.iter().map(|r| r.num_rows()).sum(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
DataFusionError::ArrowError(
|
||||
@@ -418,8 +484,6 @@ impl PartSortStream {
|
||||
)
|
||||
})?;
|
||||
|
||||
// only clear after sorted for better debugging
|
||||
self.buffer.clear();
|
||||
self.produced += sorted.num_rows();
|
||||
drop(full_input);
|
||||
// here remove both buffer and full_input memory
|
||||
@@ -427,6 +491,68 @@ impl PartSortStream {
|
||||
Ok(sorted)
|
||||
}
|
||||
|
||||
/// Internal method for sorting `Top` buffer (with limit).
|
||||
fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
|
||||
let new_top_buffer = TopK::try_new(
|
||||
self.partition,
|
||||
self.schema().clone(),
|
||||
vec![self.expression.clone()],
|
||||
self.limit.unwrap(),
|
||||
self.context.session_config().batch_size(),
|
||||
self.context.runtime_env(),
|
||||
&self.root_metrics,
|
||||
self.partition,
|
||||
)?;
|
||||
let PartSortBuffer::Top(top_k, _) =
|
||||
std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
|
||||
else {
|
||||
unreachable!("buffer type is checked before and should be Top variant")
|
||||
};
|
||||
|
||||
let mut result_stream = top_k.emit()?;
|
||||
let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
|
||||
let mut results = vec![];
|
||||
let mut row_count = 0;
|
||||
// according to the current implementation of `TopK`, the result stream will always be ready
|
||||
loop {
|
||||
match result_stream.poll_next_unpin(&mut placeholder_ctx) {
|
||||
Poll::Ready(Some(batch)) => {
|
||||
let batch = batch?;
|
||||
row_count += batch.num_rows();
|
||||
results.push(batch);
|
||||
}
|
||||
Poll::Pending => {
|
||||
#[cfg(debug_assertions)]
|
||||
unreachable!("TopK result stream should always be ready")
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let concat_batch = concat_batches(&self.schema, &results, row_count).map_err(|e| {
|
||||
DataFusionError::ArrowError(
|
||||
e,
|
||||
Some(format!(
|
||||
"Fail to concat top k result record batch when sorting at {}",
|
||||
location!()
|
||||
)),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(concat_batch)
|
||||
}
|
||||
|
||||
/// Try to split the input batch if it contains data that exceeds the current partition range.
|
||||
///
|
||||
/// When the input batch contains data that exceeds the current partition range, this function
|
||||
/// will split the input batch into two parts, the first part is within the current partition
|
||||
/// range will be merged and sorted with previous buffer, and the second part will be registered
|
||||
/// to `evaluating_batch` for next polling.
|
||||
///
|
||||
/// Returns `None` if the input batch is empty or fully within the current partition range, and
|
||||
/// `Some(batch)` otherwise.
|
||||
fn split_batch(
|
||||
&mut self,
|
||||
batch: DfRecordBatch,
|
||||
@@ -443,7 +569,7 @@ impl PartSortStream {
|
||||
|
||||
let next_range_idx = self.try_find_next_range(&sort_column)?;
|
||||
let Some(idx) = next_range_idx else {
|
||||
self.buffer.push(batch);
|
||||
self.push_buffer(batch)?;
|
||||
// keep polling input for next batch
|
||||
return Ok(None);
|
||||
};
|
||||
@@ -451,7 +577,7 @@ impl PartSortStream {
|
||||
let this_range = batch.slice(0, idx);
|
||||
let remaining_range = batch.slice(idx, batch.num_rows() - idx);
|
||||
if this_range.num_rows() != 0 {
|
||||
self.buffer.push(this_range);
|
||||
self.push_buffer(this_range)?;
|
||||
}
|
||||
// mark end of current PartitionRange
|
||||
let sorted_batch = self.sort_buffer();
|
||||
@@ -466,7 +592,7 @@ impl PartSortStream {
|
||||
// remaining batch is within the current partition range
|
||||
// push to the buffer and continue polling
|
||||
if remaining_range.num_rows() != 0 {
|
||||
self.buffer.push(remaining_range);
|
||||
self.push_buffer(remaining_range)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -556,9 +682,12 @@ mod test {
|
||||
#[tokio::test]
|
||||
async fn fuzzy_test() {
|
||||
let test_cnt = 100;
|
||||
// bound for total count of PartitionRange
|
||||
let part_cnt_bound = 100;
|
||||
// bound for timestamp range size and offset for each PartitionRange
|
||||
let range_size_bound = 100;
|
||||
let range_offset_bound = 100;
|
||||
// bound for batch count and size within each PartitionRange
|
||||
let batch_cnt_bound = 20;
|
||||
let batch_size_bound = 100;
|
||||
|
||||
@@ -575,6 +704,11 @@ mod test {
|
||||
descending,
|
||||
nulls_first,
|
||||
};
|
||||
let limit = if rng.bool() {
|
||||
Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let unit = match rng.u8(0..3) {
|
||||
0 => TimeUnit::Second,
|
||||
1 => TimeUnit::Millisecond,
|
||||
@@ -685,19 +819,39 @@ mod test {
|
||||
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit.clone(), a)])
|
||||
.unwrap()
|
||||
})
|
||||
.map(|rb| {
|
||||
// trim expected output with limit
|
||||
if let Some(limit) = limit
|
||||
&& rb.num_rows() > limit
|
||||
{
|
||||
rb.slice(0, limit)
|
||||
} else {
|
||||
rb
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
test_cases.push((
|
||||
case_id,
|
||||
unit,
|
||||
input_ranged_data,
|
||||
schema,
|
||||
opt,
|
||||
limit,
|
||||
expected_output,
|
||||
));
|
||||
}
|
||||
|
||||
for (case_id, _unit, input_ranged_data, schema, opt, expected_output) in test_cases {
|
||||
run_test(case_id, input_ranged_data, schema, opt, expected_output).await;
|
||||
for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
|
||||
run_test(
|
||||
case_id,
|
||||
input_ranged_data,
|
||||
schema,
|
||||
opt,
|
||||
limit,
|
||||
expected_output,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -711,6 +865,7 @@ mod test {
|
||||
((5, 10), vec![vec![5, 6], vec![7, 8]]),
|
||||
],
|
||||
false,
|
||||
None,
|
||||
vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
|
||||
),
|
||||
(
|
||||
@@ -720,6 +875,7 @@ mod test {
|
||||
((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
|
||||
],
|
||||
true,
|
||||
None,
|
||||
vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
|
||||
),
|
||||
(
|
||||
@@ -729,6 +885,7 @@ mod test {
|
||||
((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
|
||||
],
|
||||
true,
|
||||
None,
|
||||
vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
|
||||
),
|
||||
(
|
||||
@@ -740,6 +897,7 @@ mod test {
|
||||
((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
|
||||
],
|
||||
true,
|
||||
None,
|
||||
vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
|
||||
),
|
||||
(
|
||||
@@ -751,6 +909,7 @@ mod test {
|
||||
((0, 10), vec![]),
|
||||
],
|
||||
true,
|
||||
None,
|
||||
vec![],
|
||||
),
|
||||
(
|
||||
@@ -765,6 +924,7 @@ mod test {
|
||||
((0, 10), vec![]),
|
||||
],
|
||||
true,
|
||||
None,
|
||||
vec![
|
||||
vec![19, 17, 15],
|
||||
vec![12, 11, 10],
|
||||
@@ -772,9 +932,24 @@ mod test {
|
||||
vec![4, 3, 2, 1],
|
||||
],
|
||||
),
|
||||
(
|
||||
TimeUnit::Millisecond,
|
||||
vec![
|
||||
(
|
||||
(15, 20),
|
||||
vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
|
||||
),
|
||||
((10, 15), vec![]),
|
||||
((5, 10), vec![]),
|
||||
((0, 10), vec![]),
|
||||
],
|
||||
true,
|
||||
Some(2),
|
||||
vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
|
||||
),
|
||||
];
|
||||
|
||||
for (identifier, (unit, input_ranged_data, descending, expected_output)) in
|
||||
for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
|
||||
testcases.into_iter().enumerate()
|
||||
{
|
||||
let schema = Schema::new(vec![Field::new(
|
||||
@@ -787,6 +962,7 @@ mod test {
|
||||
descending,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let input_ranged_data = input_ranged_data
|
||||
.into_iter()
|
||||
.map(|(range, data)| {
|
||||
@@ -821,6 +997,7 @@ mod test {
|
||||
input_ranged_data,
|
||||
schema.clone(),
|
||||
opt,
|
||||
limit,
|
||||
expected_output,
|
||||
)
|
||||
.await;
|
||||
@@ -833,8 +1010,19 @@ mod test {
|
||||
input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
|
||||
schema: SchemaRef,
|
||||
opt: SortOptions,
|
||||
limit: Option<usize>,
|
||||
expected_output: Vec<DfRecordBatch>,
|
||||
) {
|
||||
for rb in &expected_output {
|
||||
if let Some(limit) = limit {
|
||||
assert!(
|
||||
rb.num_rows() <= limit,
|
||||
"Expect row count in expected output's batch({}) <= limit({})",
|
||||
rb.num_rows(),
|
||||
limit
|
||||
);
|
||||
}
|
||||
}
|
||||
let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
|
||||
|
||||
let batches = batches
|
||||
@@ -851,7 +1039,7 @@ mod test {
|
||||
expr: Arc::new(Column::new("ts", 0)),
|
||||
options: opt,
|
||||
},
|
||||
None,
|
||||
limit,
|
||||
vec![ranges.clone()],
|
||||
Arc::new(mock_input),
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user