feat: reimplement limit in PartSort to reduce memory footprint (#5018)

* feat: support windowed sort with where condition

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix split logic

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* modify fuzz test to reflect logic change

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* feat: handle sort that wont preserving partition

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* clean up

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix typo

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix test case and add more cases

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* basic impl

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* install topk

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* tests: add test for limit

* add debug assertion

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: discord9 <discord9@163.com>
This commit is contained in:
Ruihang Xia
2024-11-20 15:11:49 +08:00
committed by GitHub
parent db345c92df
commit 6a958e2c36
2 changed files with 212 additions and 23 deletions

View File

@@ -16,6 +16,7 @@
#![feature(int_roundings)]
#![feature(trait_upcasting)]
#![feature(try_blocks)]
#![feature(stmt_expr_attributes)]
mod analyze;
pub mod dataframe;

View File

@@ -33,11 +33,11 @@ use datafusion::execution::{RecordBatchStream, TaskContext};
use datafusion::physical_plan::coalesce_batches::concat_batches;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, TopK,
};
use datafusion_common::{internal_err, DataFusionError};
use datafusion_physical_expr::PhysicalSortExpr;
use futures::Stream;
use futures::{Stream, StreamExt};
use itertools::Itertools;
use snafu::location;
use store_api::region_engine::PartitionRange;
@@ -108,7 +108,7 @@ impl PartSortExec {
input_stream,
self.partition_ranges[partition].clone(),
partition,
)) as _;
)?) as _;
Ok(df_stream)
}
@@ -185,10 +185,28 @@ impl ExecutionPlan for PartSortExec {
}
}
enum PartSortBuffer {
All(Vec<DfRecordBatch>),
/// TopK buffer with row count.
///
/// Given this heap only keeps k element, the capacity of this buffer
/// is not accurate, and is only used for empty check.
Top(TopK, usize),
}
impl PartSortBuffer {
pub fn is_empty(&self) -> bool {
match self {
PartSortBuffer::All(v) => v.is_empty(),
PartSortBuffer::Top(_, cnt) => *cnt == 0,
}
}
}
struct PartSortStream {
/// Memory pool for this stream
reservation: MemoryReservation,
buffer: Vec<DfRecordBatch>,
buffer: PartSortBuffer,
expression: PhysicalSortExpr,
limit: Option<usize>,
produced: usize,
@@ -201,6 +219,8 @@ struct PartSortStream {
cur_part_idx: usize,
evaluating_batch: Option<DfRecordBatch>,
metrics: BaselineMetrics,
context: Arc<TaskContext>,
root_metrics: ExecutionPlanMetricsSet,
}
impl PartSortStream {
@@ -211,11 +231,29 @@ impl PartSortStream {
input: DfSendableRecordBatchStream,
partition_ranges: Vec<PartitionRange>,
partition: usize,
) -> Self {
Self {
) -> datafusion_common::Result<Self> {
let buffer = if let Some(limit) = limit {
PartSortBuffer::Top(
TopK::try_new(
partition,
sort.schema().clone(),
vec![sort.expression.clone()],
limit,
context.session_config().batch_size(),
context.runtime_env(),
&sort.metrics,
partition,
)?,
0,
)
} else {
PartSortBuffer::All(Vec::new())
};
Ok(Self {
reservation: MemoryConsumer::new("PartSortStream".to_string())
.register(&context.runtime_env().memory_pool),
buffer: Vec::new(),
buffer,
expression: sort.expression.clone(),
limit,
produced: 0,
@@ -227,7 +265,9 @@ impl PartSortStream {
cur_part_idx: 0,
evaluating_batch: None,
metrics: BaselineMetrics::new(&sort.metrics, partition),
}
context,
root_metrics: sort.metrics.clone(),
})
}
}
@@ -338,16 +378,42 @@ impl PartSortStream {
Ok(None)
}
fn push_buffer(&mut self, batch: DfRecordBatch) -> datafusion_common::Result<()> {
match &mut self.buffer {
PartSortBuffer::All(v) => v.push(batch),
PartSortBuffer::Top(top, cnt) => {
*cnt += batch.num_rows();
top.insert_batch(batch)?;
}
}
Ok(())
}
/// Sort and clear the buffer and return the sorted record batch
///
/// this function will return a empty record batch if the buffer is empty
fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
if self.buffer.is_empty() {
match &mut self.buffer {
PartSortBuffer::All(_) => self.sort_all_buffer(),
PartSortBuffer::Top(_, _) => self.sort_top_buffer(),
}
}
/// Internal method for sorting `All` buffer (without limit).
fn sort_all_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
let PartSortBuffer::All(buffer) =
std::mem::replace(&mut self.buffer, PartSortBuffer::All(Vec::new()))
else {
unreachable!("buffer type is checked before and should be All variant")
};
if buffer.is_empty() {
return Ok(DfRecordBatch::new_empty(self.schema.clone()));
}
let mut sort_columns = Vec::with_capacity(self.buffer.len());
let mut sort_columns = Vec::with_capacity(buffer.len());
let mut opt = None;
for batch in self.buffer.iter() {
for batch in buffer.iter() {
let sort_column = self.expression.evaluate_to_sort_column(batch)?;
opt = opt.or(sort_column.options);
sort_columns.push(sort_column.values);
@@ -390,13 +456,13 @@ impl PartSortStream {
})?;
// reserve memory for the concat input and sorted output
let total_mem: usize = self.buffer.iter().map(|r| r.get_array_memory_size()).sum();
let total_mem: usize = buffer.iter().map(|r| r.get_array_memory_size()).sum();
self.reservation.try_grow(total_mem * 2)?;
let full_input = concat_batches(
&self.schema,
&self.buffer,
self.buffer.iter().map(|r| r.num_rows()).sum(),
&buffer,
buffer.iter().map(|r| r.num_rows()).sum(),
)
.map_err(|e| {
DataFusionError::ArrowError(
@@ -418,8 +484,6 @@ impl PartSortStream {
)
})?;
// only clear after sorted for better debugging
self.buffer.clear();
self.produced += sorted.num_rows();
drop(full_input);
// here remove both buffer and full_input memory
@@ -427,6 +491,68 @@ impl PartSortStream {
Ok(sorted)
}
/// Internal method for sorting `Top` buffer (with limit).
fn sort_top_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
let new_top_buffer = TopK::try_new(
self.partition,
self.schema().clone(),
vec![self.expression.clone()],
self.limit.unwrap(),
self.context.session_config().batch_size(),
self.context.runtime_env(),
&self.root_metrics,
self.partition,
)?;
let PartSortBuffer::Top(top_k, _) =
std::mem::replace(&mut self.buffer, PartSortBuffer::Top(new_top_buffer, 0))
else {
unreachable!("buffer type is checked before and should be Top variant")
};
let mut result_stream = top_k.emit()?;
let mut placeholder_ctx = std::task::Context::from_waker(futures::task::noop_waker_ref());
let mut results = vec![];
let mut row_count = 0;
// according to the current implementation of `TopK`, the result stream will always be ready
loop {
match result_stream.poll_next_unpin(&mut placeholder_ctx) {
Poll::Ready(Some(batch)) => {
let batch = batch?;
row_count += batch.num_rows();
results.push(batch);
}
Poll::Pending => {
#[cfg(debug_assertions)]
unreachable!("TopK result stream should always be ready")
}
Poll::Ready(None) => {
break;
}
}
}
let concat_batch = concat_batches(&self.schema, &results, row_count).map_err(|e| {
DataFusionError::ArrowError(
e,
Some(format!(
"Fail to concat top k result record batch when sorting at {}",
location!()
)),
)
})?;
Ok(concat_batch)
}
/// Try to split the input batch if it contains data that exceeds the current partition range.
///
/// When the input batch contains data that exceeds the current partition range, this function
/// will split the input batch into two parts, the first part is within the current partition
/// range will be merged and sorted with previous buffer, and the second part will be registered
/// to `evaluating_batch` for next polling.
///
/// Returns `None` if the input batch is empty or fully within the current partition range, and
/// `Some(batch)` otherwise.
fn split_batch(
&mut self,
batch: DfRecordBatch,
@@ -443,7 +569,7 @@ impl PartSortStream {
let next_range_idx = self.try_find_next_range(&sort_column)?;
let Some(idx) = next_range_idx else {
self.buffer.push(batch);
self.push_buffer(batch)?;
// keep polling input for next batch
return Ok(None);
};
@@ -451,7 +577,7 @@ impl PartSortStream {
let this_range = batch.slice(0, idx);
let remaining_range = batch.slice(idx, batch.num_rows() - idx);
if this_range.num_rows() != 0 {
self.buffer.push(this_range);
self.push_buffer(this_range)?;
}
// mark end of current PartitionRange
let sorted_batch = self.sort_buffer();
@@ -466,7 +592,7 @@ impl PartSortStream {
// remaining batch is within the current partition range
// push to the buffer and continue polling
if remaining_range.num_rows() != 0 {
self.buffer.push(remaining_range);
self.push_buffer(remaining_range)?;
}
}
@@ -556,9 +682,12 @@ mod test {
#[tokio::test]
async fn fuzzy_test() {
let test_cnt = 100;
// bound for total count of PartitionRange
let part_cnt_bound = 100;
// bound for timestamp range size and offset for each PartitionRange
let range_size_bound = 100;
let range_offset_bound = 100;
// bound for batch count and size within each PartitionRange
let batch_cnt_bound = 20;
let batch_size_bound = 100;
@@ -575,6 +704,11 @@ mod test {
descending,
nulls_first,
};
let limit = if rng.bool() {
Some(rng.usize(0..batch_cnt_bound * batch_size_bound))
} else {
None
};
let unit = match rng.u8(0..3) {
0 => TimeUnit::Second,
1 => TimeUnit::Millisecond,
@@ -685,19 +819,39 @@ mod test {
DfRecordBatch::try_new(schema.clone(), vec![new_ts_array(unit.clone(), a)])
.unwrap()
})
.map(|rb| {
// trim expected output with limit
if let Some(limit) = limit
&& rb.num_rows() > limit
{
rb.slice(0, limit)
} else {
rb
}
})
.collect_vec();
test_cases.push((
case_id,
unit,
input_ranged_data,
schema,
opt,
limit,
expected_output,
));
}
for (case_id, _unit, input_ranged_data, schema, opt, expected_output) in test_cases {
run_test(case_id, input_ranged_data, schema, opt, expected_output).await;
for (case_id, _unit, input_ranged_data, schema, opt, limit, expected_output) in test_cases {
run_test(
case_id,
input_ranged_data,
schema,
opt,
limit,
expected_output,
)
.await;
}
}
@@ -711,6 +865,7 @@ mod test {
((5, 10), vec![vec![5, 6], vec![7, 8]]),
],
false,
None,
vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
),
(
@@ -720,6 +875,7 @@ mod test {
((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
],
true,
None,
vec![vec![9, 8, 7, 6, 5], vec![8, 7, 6, 5, 4, 3, 2, 1]],
),
(
@@ -729,6 +885,7 @@ mod test {
((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
],
true,
None,
vec![vec![8, 7, 6, 5, 4, 3, 2, 1]],
),
(
@@ -740,6 +897,7 @@ mod test {
((0, 10), vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8]]),
],
true,
None,
vec![vec![19, 18, 17], vec![8, 7, 6, 5, 4, 3, 2, 1]],
),
(
@@ -751,6 +909,7 @@ mod test {
((0, 10), vec![]),
],
true,
None,
vec![],
),
(
@@ -765,6 +924,7 @@ mod test {
((0, 10), vec![]),
],
true,
None,
vec![
vec![19, 17, 15],
vec![12, 11, 10],
@@ -772,9 +932,24 @@ mod test {
vec![4, 3, 2, 1],
],
),
(
TimeUnit::Millisecond,
vec![
(
(15, 20),
vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
),
((10, 15), vec![]),
((5, 10), vec![]),
((0, 10), vec![]),
],
true,
Some(2),
vec![vec![19, 17], vec![12, 11], vec![9, 8], vec![4, 3]],
),
];
for (identifier, (unit, input_ranged_data, descending, expected_output)) in
for (identifier, (unit, input_ranged_data, descending, limit, expected_output)) in
testcases.into_iter().enumerate()
{
let schema = Schema::new(vec![Field::new(
@@ -787,6 +962,7 @@ mod test {
descending,
..Default::default()
};
let input_ranged_data = input_ranged_data
.into_iter()
.map(|(range, data)| {
@@ -821,6 +997,7 @@ mod test {
input_ranged_data,
schema.clone(),
opt,
limit,
expected_output,
)
.await;
@@ -833,8 +1010,19 @@ mod test {
input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
schema: SchemaRef,
opt: SortOptions,
limit: Option<usize>,
expected_output: Vec<DfRecordBatch>,
) {
for rb in &expected_output {
if let Some(limit) = limit {
assert!(
rb.num_rows() <= limit,
"Expect row count in expected output's batch({}) <= limit({})",
rb.num_rows(),
limit
);
}
}
let (ranges, batches): (Vec<_>, Vec<_>) = input_ranged_data.clone().into_iter().unzip();
let batches = batches
@@ -851,7 +1039,7 @@ mod test {
expr: Arc::new(Column::new("ts", 0)),
options: opt,
},
None,
limit,
vec![ranges.clone()],
Arc::new(mock_input),
);