feat: support filter with windowed sort (#4960)

* feat: support windowed sort with where condition

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix split logic

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* modify fuzz test to reflect logic change

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* feat: handle sort that won't preserve partitioning

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* clean up

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix typo

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix test case and add more cases

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Ruihang Xia
2024-11-08 10:49:36 +08:00
committed by GitHub
parent fcd0ceea94
commit 8efbafa538
3 changed files with 254 additions and 45 deletions

View File

@@ -19,6 +19,7 @@ use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::Result as DataFusionResult;
@@ -67,10 +68,12 @@ impl WindowedSortPhysicalRule {
.transform_down(|plan| {
if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
// TODO: support multiple expr in windowed sort
if !sort_exec.preserve_partitioning() || sort_exec.expr().len() != 1 {
if sort_exec.expr().len() != 1 {
return Ok(Transformed::no(plan));
}
let preserve_partitioning = sort_exec.preserve_partitioning();
let Some(scanner_info) = fetch_partition_range(sort_exec.input().clone())?
else {
return Ok(Transformed::no(plan));
@@ -111,11 +114,23 @@ impl WindowedSortPhysicalRule {
new_input,
)?;
return Ok(Transformed {
data: Arc::new(windowed_sort_exec),
transformed: true,
tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
});
if !preserve_partitioning {
let order_preserving_merge = SortPreservingMergeExec::new(
sort_exec.expr().to_vec(),
Arc::new(windowed_sort_exec),
);
return Ok(Transformed {
data: Arc::new(order_preserving_merge),
transformed: true,
tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
});
} else {
return Ok(Transformed {
data: Arc::new(windowed_sort_exec),
transformed: true,
tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
});
}
}
Ok(Transformed::no(plan))
@@ -126,6 +141,7 @@ impl WindowedSortPhysicalRule {
}
}
#[derive(Debug)]
struct ScannerInfo {
partition_ranges: Vec<Vec<PartitionRange>>,
time_index: String,
@@ -136,11 +152,11 @@ fn fetch_partition_range(input: Arc<dyn ExecutionPlan>) -> DataFusionResult<Opti
let mut partition_ranges = None;
let mut time_index = None;
let mut tag_columns = None;
let mut is_batch_coalesced = false;
input.transform_up(|plan| {
// Inapplicable case, reset the state.
if plan.as_any().is::<RepartitionExec>()
|| plan.as_any().is::<CoalesceBatchesExec>()
|| plan.as_any().is::<CoalescePartitionsExec>()
|| plan.as_any().is::<SortExec>()
|| plan.as_any().is::<WindowedSortExec>()
@@ -148,13 +164,19 @@ fn fetch_partition_range(input: Arc<dyn ExecutionPlan>) -> DataFusionResult<Opti
partition_ranges = None;
}
if plan.as_any().is::<CoalesceBatchesExec>() {
is_batch_coalesced = true;
}
if let Some(region_scan_exec) = plan.as_any().downcast_ref::<RegionScanExec>() {
partition_ranges = Some(region_scan_exec.get_uncollapsed_partition_ranges());
time_index = Some(region_scan_exec.time_index());
tag_columns = Some(region_scan_exec.tag_columns());
// set distinguish_partition_ranges to true, this is an incorrect workaround
region_scan_exec.with_distinguish_partition_range(true);
if !is_batch_coalesced {
region_scan_exec.with_distinguish_partition_range(true);
}
}
Ok(Transformed::no(plan))
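
The net effect of this file's change: instead of bailing out whenever the SortExec does not preserve partitioning, the rule still rewrites it into a WindowedSortExec and, in the non-preserving case, places a SortPreservingMergeExec on top so the per-partition sorted streams are merged back into one globally sorted stream. A minimal sketch of that branch follows; wrap_windowed_sort is a hypothetical helper name used only for illustration, not part of the actual rule.

use std::sync::Arc;

use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::ExecutionPlan;

// Illustrative only: mirrors the new if/else in WindowedSortPhysicalRule.
fn wrap_windowed_sort(
    sort_expr: Vec<PhysicalSortExpr>,
    windowed_sort: Arc<dyn ExecutionPlan>,
    preserve_partitioning: bool,
) -> Arc<dyn ExecutionPlan> {
    if preserve_partitioning {
        // Each output partition is consumed independently, so the
        // per-partition windowed sort output can be returned as-is.
        windowed_sort
    } else {
        // The original SortExec produced a single sorted stream, so the
        // per-partition outputs must be merged while preserving order.
        Arc::new(SortPreservingMergeExec::new(sort_expr, windowed_sort))
    }
}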

View File

@@ -12,6 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! Module for sorting input data within each [`PartitionRange`].
//!
//! This module defines the [`PartSortExec`] execution plan, which sorts each
//! partition ([`PartitionRange`]) independently based on the provided physical
//! sort expressions.
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
@@ -36,7 +42,7 @@ use itertools::Itertools;
use snafu::location;
use store_api::region_engine::PartitionRange;
use crate::downcast_ts_array;
use crate::{array_iter_helper, downcast_ts_array};
/// Sort input within given PartitionRange
///
@@ -193,6 +199,7 @@ struct PartSortStream {
#[allow(dead_code)] // this is used under #[debug_assertions]
partition: usize,
cur_part_idx: usize,
evaluating_batch: Option<DfRecordBatch>,
metrics: BaselineMetrics,
}
@@ -218,6 +225,7 @@ impl PartSortStream {
partition_ranges,
partition,
cur_part_idx: 0,
evaluating_batch: None,
metrics: BaselineMetrics::new(&sort.metrics, partition),
}
}
@@ -288,9 +296,51 @@ impl PartSortStream {
Ok(())
}
/// Try to find data whose value exceeds the current partition range.
///
/// Returns `None` if no such data is found, and `Some(idx)` where idx points to
/// the first data that exceeds the current partition range.
fn try_find_next_range(
&self,
sort_column: &ArrayRef,
) -> datafusion_common::Result<Option<usize>> {
if sort_column.len() == 0 {
return Ok(Some(0));
}
// check if the current partition index is out of range
if self.cur_part_idx >= self.partition_ranges.len() {
internal_err!(
"Partition index out of range: {} >= {}",
self.cur_part_idx,
self.partition_ranges.len()
)?;
}
let cur_range = self.partition_ranges[self.cur_part_idx];
let sort_column_iter = downcast_ts_array!(
sort_column.data_type() => (array_iter_helper, sort_column),
_ => internal_err!(
"Unsupported data type for sort column: {:?}",
sort_column.data_type()
)?,
);
for (idx, val) in sort_column_iter {
// ignore vacant time index data
if let Some(val) = val {
if val >= cur_range.end.value() || val < cur_range.start.value() {
return Ok(Some(idx));
}
}
}
Ok(None)
}
/// Sort and clear the buffer and return the sorted record batch
///
/// this function should return an empty record batch if the buffer is empty
/// this function will return an empty record batch if the buffer is empty
fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
if self.buffer.is_empty() {
return Ok(DfRecordBatch::new_empty(self.schema.clone()));
@@ -317,6 +367,9 @@ impl PartSortStream {
Some(format!("Fail to sort to indices at {}", location!())),
)
})?;
if indices.is_empty() {
return Ok(DfRecordBatch::new_empty(self.schema.clone()));
}
self.check_in_range(
&sort_column,
@@ -374,11 +427,58 @@ impl PartSortStream {
Ok(sorted)
}
fn split_batch(
&mut self,
batch: DfRecordBatch,
) -> datafusion_common::Result<Option<DfRecordBatch>> {
if batch.num_rows() == 0 {
return Ok(None);
}
let sort_column = self
.expression
.expr
.evaluate(&batch)?
.into_array(batch.num_rows())?;
let next_range_idx = self.try_find_next_range(&sort_column)?;
let Some(idx) = next_range_idx else {
self.buffer.push(batch);
// keep polling input for next batch
return Ok(None);
};
let this_range = batch.slice(0, idx);
let remaining_range = batch.slice(idx, batch.num_rows() - idx);
if this_range.num_rows() != 0 {
self.buffer.push(this_range);
}
// mark end of current PartitionRange
let sorted_batch = self.sort_buffer();
// step to next proper PartitionRange
self.cur_part_idx += 1;
let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
if self.try_find_next_range(&next_sort_column)?.is_some() {
// remaining batch still contains data that exceeds the current partition range
// register the remaining batch for next polling
self.evaluating_batch = Some(remaining_range);
} else {
// remaining batch is within the current partition range
// push to the buffer and continue polling
if remaining_range.num_rows() != 0 {
self.buffer.push(remaining_range);
}
}
sorted_batch.map(|x| if x.num_rows() == 0 { None } else { Some(x) })
}
pub fn poll_next_inner(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
loop {
// no more input, sort the buffer and return
if self.input_complete {
if self.buffer.is_empty() {
return Poll::Ready(None);
@@ -386,24 +486,30 @@ impl PartSortStream {
return Poll::Ready(Some(self.sort_buffer()));
}
}
// if there is a remaining batch being evaluated from last run,
// split on it instead of fetching new batch
if let Some(evaluating_batch) = self.evaluating_batch.take()
&& evaluating_batch.num_rows() != 0
{
if let Some(sorted_batch) = self.split_batch(evaluating_batch)? {
return Poll::Ready(Some(Ok(sorted_batch)));
} else {
continue;
}
}
// fetch next batch from input
let res = self.input.as_mut().poll_next(cx);
match res {
Poll::Ready(Some(Ok(batch))) => {
if batch.num_rows() == 0 {
// mark end of current PartitionRange
let sorted_batch = self.sort_buffer()?;
self.cur_part_idx += 1;
if sorted_batch.num_rows() == 0 {
// Current part is empty, continue polling next part.
continue;
}
if let Some(sorted_batch) = self.split_batch(batch)? {
return Poll::Ready(Some(Ok(sorted_batch)));
} else {
continue;
}
self.buffer.push(batch);
// keep polling until a boundary (an empty RecordBatch) is reached
continue;
}
// input stream end, sort the buffer and return
// input stream end, mark and continue
Poll::Ready(None) => {
self.input_complete = true;
continue;
@@ -484,14 +590,19 @@ mod test {
let schema = Arc::new(schema);
let mut input_ranged_data = vec![];
let mut output_ranges = vec![];
let mut output_data = vec![];
// generate each input `PartitionRange`
for part_id in 0..rng.usize(0..part_cnt_bound) {
// generate each `PartitionRange`'s timestamp range
let (start, end) = if descending {
let end = bound_val
.map(|i| i.checked_sub(rng.i64(0..range_offset_bound)).expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again"))
.unwrap_or_else(|| rng.i64(..));
.map(
|i| i
.checked_sub(rng.i64(0..range_offset_bound))
.expect("Bad luck, fuzzy test generate data that will overflow, change seed and try again")
)
.unwrap_or_else(|| rng.i64(-100000000..100000000));
bound_val = Some(end);
let start = end - rng.i64(1..range_size_bound);
let start = Timestamp::new(start, unit.clone().into());
@@ -514,13 +625,15 @@ mod test {
for _batch_idx in 0..rng.usize(1..batch_cnt_bound) {
let cnt = rng.usize(0..batch_size_bound) + 1;
let iter = 0..rng.usize(0..cnt);
let data_gen = iter
let mut data_gen = iter
.map(|_| rng.i64(start.value()..end.value()))
.collect_vec();
if data_gen.is_empty() {
// current batch is empty, skip
continue;
}
// mito always sorts in ASC order
data_gen.sort();
per_part_sort_data.extend(data_gen.clone());
let arr = new_ts_array(unit.clone(), data_gen.clone());
let batch = DfRecordBatch::try_new(schema.clone(), vec![arr]).unwrap();
@@ -535,15 +648,35 @@ mod test {
};
input_ranged_data.push((range, batches));
if descending {
per_part_sort_data.sort_by(|a, b| b.cmp(a));
} else {
per_part_sort_data.sort();
}
output_ranges.push(range);
if per_part_sort_data.is_empty() {
continue;
}
output_data.push(per_part_sort_data);
output_data.extend_from_slice(&per_part_sort_data);
}
// adjust output data with adjacent PartitionRanges
let mut output_data_iter = output_data.iter().peekable();
let mut output_data = vec![];
for range in output_ranges.clone() {
let mut cur_data = vec![];
while let Some(val) = output_data_iter.peek() {
if **val < range.start.value() || **val >= range.end.value() {
break;
}
cur_data.push(*output_data_iter.next().unwrap());
}
if cur_data.is_empty() {
continue;
}
if descending {
cur_data.sort_by(|a, b| b.cmp(a));
} else {
cur_data.sort();
}
output_data.push(cur_data);
}
let expected_output = output_data
@@ -578,7 +711,7 @@ mod test {
((5, 10), vec![vec![5, 6], vec![7, 8]]),
],
false,
vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![5, 6, 7, 8]],
vec![vec![1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9]],
),
(
TimeUnit::Millisecond,
@@ -620,6 +753,25 @@ mod test {
true,
vec![],
),
(
TimeUnit::Millisecond,
vec![
(
(15, 20),
vec![vec![15, 17, 19, 10, 11, 12, 5, 6, 7, 8, 9, 1, 2, 3, 4]],
),
((10, 15), vec![]),
((5, 10), vec![]),
((0, 10), vec![]),
],
true,
vec![
vec![19, 17, 15],
vec![12, 11, 10],
vec![9, 8, 7, 6, 5],
vec![4, 3, 2, 1],
],
),
];
for (identifier, (unit, input_ranged_data, descending, expected_output)) in
@@ -664,10 +816,18 @@ mod test {
})
.collect_vec();
run_test(0, input_ranged_data, schema.clone(), opt, expected_output).await;
run_test(
identifier,
input_ranged_data,
schema.clone(),
opt,
expected_output,
)
.await;
}
}
#[allow(clippy::print_stdout)]
async fn run_test(
case_id: usize,
input_ranged_data: Vec<(PartitionRange, Vec<DfRecordBatch>)>,
@@ -692,20 +852,36 @@ mod test {
options: opt,
},
None,
vec![ranges],
vec![ranges.clone()],
Arc::new(mock_input),
);
let exec_stream = exec.execute(0, Arc::new(TaskContext::default())).unwrap();
let real_output = exec_stream.map(|r| r.unwrap()).collect::<Vec<_>>().await;
// a makeshift solution for comparing large data
if real_output != expected_output {
let mut first_diff = 0;
for (idx, (lhs, rhs)) in real_output.iter().zip(expected_output.iter()).enumerate() {
if lhs != rhs {
first_diff = idx;
break;
}
}
println!("first diff batch at {}", first_diff);
println!(
"ranges: {:?}",
ranges
.into_iter()
.map(|r| (r.start.to_chrono_datetime(), r.end.to_chrono_datetime()))
.enumerate()
.collect::<Vec<_>>()
);
let mut full_msg = String::new();
{
let mut buf = Vec::with_capacity(10 * real_output.len());
for batch in &real_output {
for batch in real_output.iter().skip(first_diff) {
let mut rb_json: Vec<u8> = Vec::new();
let mut writer = ArrayWriter::new(&mut rb_json);
writer.write(batch).unwrap();
@@ -714,12 +890,12 @@ mod test {
buf.push(b',');
}
// TODO(discord9): better ways to print buf
let _buf = String::from_utf8_lossy(&buf);
full_msg += &format!("case_id:{case_id}, real_output");
let buf = String::from_utf8_lossy(&buf);
full_msg += &format!("\ncase_id:{case_id}, real_output \n{buf}\n");
}
{
let mut buf = Vec::with_capacity(10 * real_output.len());
for batch in &expected_output {
for batch in expected_output.iter().skip(first_diff) {
let mut rb_json: Vec<u8> = Vec::new();
let mut writer = ArrayWriter::new(&mut rb_json);
writer.write(batch).unwrap();
@@ -727,12 +903,16 @@ mod test {
buf.append(&mut rb_json);
buf.push(b',');
}
let _buf = String::from_utf8_lossy(&buf);
full_msg += &format!("case_id:{case_id}, expected_output");
let buf = String::from_utf8_lossy(&buf);
full_msg += &format!("case_id:{case_id}, expected_output \n{buf}");
}
panic!(
"case_{} failed, opt: {:?}, full msg: {}",
case_id, opt, full_msg
"case_{} failed, opt: {:?},\n real output has {} batches, {} rows, expected has {} batches with {} rows\nfull msg: {}",
case_id, opt,
real_output.len(),
real_output.iter().map(|x|x.num_rows()).sum::<usize>(),
expected_output.len(),
expected_output.iter().map(|x|x.num_rows()).sum::<usize>(), full_msg
);
}
}
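
To summarize the splitting logic this file introduces: each incoming batch's sort column is scanned for the first value that falls outside the current PartitionRange; everything before that index is buffered and sorted for the current range, and the tail is carried over as evaluating_batch for the next range. A self-contained sketch of the idea on plain i64 values follows; the function names are assumptions for illustration, not the PartSortStream implementation.

// Index of the first value outside the current [start, end) range, or None
// if the whole batch belongs to the current PartitionRange.
fn find_split_point(values: &[i64], start: i64, end: i64) -> Option<usize> {
    values.iter().position(|&v| v < start || v >= end)
}

// Head goes to the current range's buffer (and triggers a sort_buffer-like
// flush); the tail is kept for the next range, like evaluating_batch.
fn split_at_range(values: &[i64], start: i64, end: i64) -> (Vec<i64>, Vec<i64>) {
    match find_split_point(values, start, end) {
        None => (values.to_vec(), Vec::new()),
        Some(idx) => (values[..idx].to_vec(), values[idx..].to_vec()),
    }
}

fn main() {
    // Mirrors the start of the new descending test case: range (15, 20)
    // receives [15, 17, 19, 10, 11, 12, ...]; values below 15 belong to the
    // following PartitionRange and must not be sorted into this one.
    let (head, tail) = split_at_range(&[15, 17, 19, 10, 11, 12], 15, 20);
    assert_eq!(head, vec![15, 17, 19]);
    assert_eq!(tail, vec![10, 11, 12]);
}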

View File

@@ -21,7 +21,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use arrow::array::{Array, ArrayRef, PrimitiveArray};
use arrow::array::{Array, ArrayRef};
use arrow::compute::SortColumn;
use arrow_schema::{DataType, SchemaRef, SortOptions};
use common_error::ext::{BoxedError, PlainError};
@@ -812,9 +812,16 @@ fn find_slice_from_range(
Ok((start, end - start))
}
/// Get an iterator from a primitive array.
///
/// Used with `downcast_ts_array`. The returned iter is wrapped with `.enumerate()`.
#[macro_export]
macro_rules! array_iter_helper {
($t:ty, $unit:expr, $arr:expr) => {{
let typed = $arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap();
let typed = $arr
.as_any()
.downcast_ref::<arrow::array::PrimitiveArray<$t>>()
.unwrap();
let iter = typed.iter().enumerate();
Box::new(iter) as Box<dyn Iterator<Item = (usize, Option<i64>)>>
}};
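
For reference, a rough sketch of what array_iter_helper! expands to for one concrete timestamp type. The millisecond type and the enumerate_millis name are chosen here purely for illustration; in the real code downcast_ts_array! dispatches on the actual DataType at runtime.

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, PrimitiveArray, TimestampMillisecondArray};
use arrow::datatypes::TimestampMillisecondType;

// Downcast the dynamically typed array to the matching PrimitiveArray and
// yield (row index, Option<i64>) pairs, as the macro's `.iter().enumerate()` does.
fn enumerate_millis(arr: &ArrayRef) -> impl Iterator<Item = (usize, Option<i64>)> + '_ {
    let typed = arr
        .as_any()
        .downcast_ref::<PrimitiveArray<TimestampMillisecondType>>()
        .expect("expected a millisecond timestamp array");
    typed.iter().enumerate()
}

fn main() {
    let arr: ArrayRef = Arc::new(TimestampMillisecondArray::from(vec![Some(1), None, Some(3)]));
    for (idx, val) in enumerate_millis(&arr) {
        // prints `0: Some(1)`, `1: None`, `2: Some(3)`
        println!("{idx}: {val:?}");
    }
}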