Compare commits

...

13 Commits

Author SHA1 Message Date
Ruihang Xia
d4aa4159d4 feat: support windowed sort with where condition
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-11-04 19:34:03 +08:00
evenyag
960f6d821b feat: spawn block write wal 2024-11-04 17:35:12 +08:00
Ruihang Xia
9c5d044238 Merge branch 'main' into transform-count-min-max
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-11-01 17:45:28 +08:00
Ruihang Xia
70c354eed6 fix: the way to retrieve time index column
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-11-01 12:10:12 +08:00
Ruihang Xia
23bf663d58 feat: handle sort that wont preserving partition
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-10-31 22:13:36 +08:00
Ruihang Xia
817648eac5 Merge branch 'main' into transform-count-min-max
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-10-31 15:38:12 +08:00
Ruihang Xia
03b29439e2 Merge branch 'main' into transform-count-min-max
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-09-11 11:09:07 +08:00
Ruihang Xia
712f4ca0ef try sort partial commutative
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-09-09 21:08:59 +08:00
Ruihang Xia
60bacff57e ignore unmatched left and right greater
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-09-08 11:12:21 +08:00
Ruihang Xia
6208772ba4 Merge branch 'main' into transform-count-min-max
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-09-08 11:02:04 +08:00
Ruihang Xia
67184c0498 Merge branch 'main' into transform-count-min-max
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-09-05 14:30:47 +08:00
Ruihang Xia
1dd908fdf7 handle group by
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-09-05 12:50:13 +08:00
Ruihang Xia
8179b4798e feat: support transforming min/max/count aggr fn
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2024-09-04 22:17:31 +08:00
7 changed files with 371 additions and 30 deletions

View File

@@ -17,6 +17,7 @@ use std::any::Any;
use common_error::ext::ErrorExt; use common_error::ext::ErrorExt;
use common_macro::stack_trace_debug; use common_macro::stack_trace_debug;
use common_runtime::error::Error as RuntimeError; use common_runtime::error::Error as RuntimeError;
use common_runtime::JoinError;
use serde_json::error::Error as JsonError; use serde_json::error::Error as JsonError;
use snafu::{Location, Snafu}; use snafu::{Location, Snafu};
use store_api::storage::RegionId; use store_api::storage::RegionId;
@@ -306,6 +307,14 @@ pub enum Error {
#[snafu(implicit)] #[snafu(implicit)]
location: Location, location: Location,
}, },
#[snafu(display("Join error"))]
Join {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: JoinError,
},
} }
impl ErrorExt for Error { impl ErrorExt for Error {

View File

@@ -31,8 +31,8 @@ use store_api::storage::RegionId;
use crate::error::{ use crate::error::{
AddEntryLogBatchSnafu, DiscontinuousLogIndexSnafu, Error, FetchEntrySnafu, AddEntryLogBatchSnafu, DiscontinuousLogIndexSnafu, Error, FetchEntrySnafu,
IllegalNamespaceSnafu, IllegalStateSnafu, InvalidProviderSnafu, OverrideCompactedEntrySnafu, IllegalNamespaceSnafu, IllegalStateSnafu, InvalidProviderSnafu, JoinSnafu,
RaftEngineSnafu, Result, StartGcTaskSnafu, StopGcTaskSnafu, OverrideCompactedEntrySnafu, RaftEngineSnafu, Result, StartGcTaskSnafu, StopGcTaskSnafu,
}; };
use crate::metrics; use crate::metrics;
use crate::raft_engine::backend::SYSTEM_NAMESPACE; use crate::raft_engine::backend::SYSTEM_NAMESPACE;
@@ -250,6 +250,12 @@ impl LogStore for RaftEngineLogStore {
.engine .engine
.write(&mut batch, sync) .write(&mut batch, sync)
.context(RaftEngineSnafu)?; .context(RaftEngineSnafu)?;
let engine = self.engine.clone();
let _ = common_runtime::spawn_blocking_global(move || {
engine.write(&mut batch, sync).context(RaftEngineSnafu)
})
.await
.context(JoinSnafu)?;
Ok(AppendBatchResponse { last_entry_ids }) Ok(AppendBatchResponse { last_entry_ids })
} }

View File

@@ -274,7 +274,7 @@ impl<'a> RuleChecker<'a> {
fn check_axis(&self) -> Result<()> { fn check_axis(&self) -> Result<()> {
for (col_index, axis) in self.axis.iter().enumerate() { for (col_index, axis) in self.axis.iter().enumerate() {
for (val, split_point) in axis { for (val, split_point) in axis {
if split_point.less_than_counter != 0 || !split_point.is_equal { if !split_point.is_equal {
UnclosedValueSnafu { UnclosedValueSnafu {
value: format!("{val:?}"), value: format!("{val:?}"),
column: self.rule.partition_columns[col_index].clone(), column: self.rule.partition_columns[col_index].clone(),
@@ -410,6 +410,7 @@ mod tests {
/// b <= h b >= s /// b <= h b >= s
/// ``` /// ```
#[test] #[test]
#[ignore = "don't check unmatched `>` and `<` for now"]
fn empty_expr_case_1() { fn empty_expr_case_1() {
// PARTITION ON COLUMNS (b) ( // PARTITION ON COLUMNS (b) (
// b <= 'h', // b <= 'h',
@@ -451,6 +452,7 @@ mod tests {
/// 10 20 /// 10 20
/// ``` /// ```
#[test] #[test]
#[ignore = "don't check unmatched `>` and `<` for now"]
fn empty_expr_case_2() { fn empty_expr_case_2() {
// PARTITION ON COLUMNS (b) ( // PARTITION ON COLUMNS (b) (
// a >= 100 AND b <= 10 OR a > 100 AND a <= 200 AND b <= 10 OR a >= 200 AND b > 10 AND b <= 20 OR a > 200 AND b <= 20 // a >= 100 AND b <= 10 OR a > 100 AND a <= 200 AND b <= 10 OR a >= 200 AND b > 10 AND b <= 20 OR a > 200 AND b <= 20
@@ -580,6 +582,7 @@ mod tests {
} }
#[test] #[test]
#[ignore = "don't check unmatched `>` and `<` for now"]
fn duplicate_expr_case_1() { fn duplicate_expr_case_1() {
// PARTITION ON COLUMNS (a) ( // PARTITION ON COLUMNS (a) (
// a <= 20, // a <= 20,

View File

@@ -15,8 +15,11 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use datafusion::functions_aggregate::sum::Sum;
use datafusion_expr::aggregate_function::AggregateFunction as BuiltInAggregateFunction;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition};
use datafusion_expr::utils::exprlist_to_columns; use datafusion_expr::utils::exprlist_to_columns;
use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode}; use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, UserDefinedLogicalNode};
use promql::extension_plan::{ use promql::extension_plan::{
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize, EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
}; };
@@ -25,21 +28,91 @@ use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan}
use crate::dist_plan::MergeScanLogicalPlan; use crate::dist_plan::MergeScanLogicalPlan;
#[allow(dead_code)] #[allow(dead_code)]
pub enum Commutativity { pub enum Commutativity<T> {
Commutative, Commutative,
PartialCommutative, PartialCommutative,
ConditionalCommutative(Option<Transformer>), ConditionalCommutative(Option<Transformer<T>>),
TransformedCommutative(Option<Transformer>), TransformedCommutative(Option<Transformer<T>>),
NonCommutative, NonCommutative,
Unimplemented, Unimplemented,
/// For unrelated plans like DDL /// For unrelated plans like DDL
Unsupported, Unsupported,
} }
impl<T> Commutativity<T> {
/// Check if self is stricter than `lhs`
fn is_stricter_than(&self, lhs: &Self) -> bool {
match (lhs, self) {
(Commutativity::Commutative, Commutativity::Commutative) => false,
(Commutativity::Commutative, _) => true,
(
Commutativity::PartialCommutative,
Commutativity::Commutative | Commutativity::PartialCommutative,
) => false,
(Commutativity::PartialCommutative, _) => true,
(
Commutativity::ConditionalCommutative(_),
Commutativity::Commutative
| Commutativity::PartialCommutative
| Commutativity::ConditionalCommutative(_),
) => false,
(Commutativity::ConditionalCommutative(_), _) => true,
(
Commutativity::TransformedCommutative(_),
Commutativity::Commutative
| Commutativity::PartialCommutative
| Commutativity::ConditionalCommutative(_)
| Commutativity::TransformedCommutative(_),
) => false,
(Commutativity::TransformedCommutative(_), _) => true,
(
Commutativity::NonCommutative
| Commutativity::Unimplemented
| Commutativity::Unsupported,
_,
) => false,
}
}
/// Return a bare commutative level without any transformer
fn bare_level<To>(&self) -> Commutativity<To> {
match self {
Commutativity::Commutative => Commutativity::Commutative,
Commutativity::PartialCommutative => Commutativity::PartialCommutative,
Commutativity::ConditionalCommutative(_) => Commutativity::ConditionalCommutative(None),
Commutativity::TransformedCommutative(_) => Commutativity::TransformedCommutative(None),
Commutativity::NonCommutative => Commutativity::NonCommutative,
Commutativity::Unimplemented => Commutativity::Unimplemented,
Commutativity::Unsupported => Commutativity::Unsupported,
}
}
}
impl<T> std::fmt::Debug for Commutativity<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Commutativity::Commutative => write!(f, "Commutative"),
Commutativity::PartialCommutative => write!(f, "PartialCommutative"),
Commutativity::ConditionalCommutative(_) => write!(f, "ConditionalCommutative"),
Commutativity::TransformedCommutative(_) => write!(f, "TransformedCommutative"),
Commutativity::NonCommutative => write!(f, "NonCommutative"),
Commutativity::Unimplemented => write!(f, "Unimplemented"),
Commutativity::Unsupported => write!(f, "Unsupported"),
}
}
}
pub struct Categorizer {} pub struct Categorizer {}
impl Categorizer { impl Categorizer {
pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<Vec<String>>) -> Commutativity { pub fn check_plan(
plan: &LogicalPlan,
partition_cols: Option<Vec<String>>,
) -> Commutativity<LogicalPlan> {
let partition_cols = partition_cols.unwrap_or_default(); let partition_cols = partition_cols.unwrap_or_default();
match plan { match plan {
@@ -47,21 +120,104 @@ impl Categorizer {
for expr in &proj.expr { for expr in &proj.expr {
let commutativity = Self::check_expr(expr); let commutativity = Self::check_expr(expr);
if !matches!(commutativity, Commutativity::Commutative) { if !matches!(commutativity, Commutativity::Commutative) {
return commutativity; return commutativity.bare_level();
} }
} }
Commutativity::Commutative Commutativity::Commutative
} }
// TODO(ruihang): Change this to Commutative once Like is supported in substrait // TODO(ruihang): Change this to Commutative once Like is supported in substrait
LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate), LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate).bare_level(),
LogicalPlan::Window(_) => Commutativity::Unimplemented, LogicalPlan::Window(_) => Commutativity::Unimplemented,
LogicalPlan::Aggregate(aggr) => { LogicalPlan::Aggregate(aggr) => {
// fast path: if the group_expr is a subset of partition_cols
if Self::check_partition(&aggr.group_expr, &partition_cols) { if Self::check_partition(&aggr.group_expr, &partition_cols) {
return Commutativity::Commutative; return Commutativity::Commutative;
} }
// check all children exprs and uses the strictest level common_telemetry::info!("[DEBUG] aggregate plan expr: {:?}", aggr.aggr_expr);
Commutativity::Unimplemented
// get all commutativity levels of aggregate exprs and find the strictest one
let aggr_expr_comm = aggr
.aggr_expr
.iter()
.map(Self::check_expr)
.collect::<Vec<_>>();
let mut strictest = Commutativity::Commutative;
for comm in &aggr_expr_comm {
if comm.is_stricter_than(&strictest) {
strictest = comm.bare_level();
}
}
common_telemetry::info!("[DEBUG] aggr_expr_comm: {:?}", aggr_expr_comm);
common_telemetry::info!("[DEBUG] strictest: {:?}", strictest);
// fast path: if any expr is commutative or non-commutative
if matches!(
strictest,
Commutativity::Commutative
| Commutativity::NonCommutative
| Commutativity::Unimplemented
| Commutativity::Unsupported
) {
return strictest.bare_level();
}
common_telemetry::info!("[DEBUG] continue for strictest",);
// collect expr transformers
let mut expr_transformer = Vec::with_capacity(aggr.aggr_expr.len());
for expr_comm in aggr_expr_comm {
match expr_comm {
Commutativity::Commutative => expr_transformer.push(None),
Commutativity::ConditionalCommutative(transformer) => {
expr_transformer.push(transformer.clone());
}
Commutativity::PartialCommutative => expr_transformer
.push(Some(Arc::new(expr_partial_commutative_transformer))),
_ => expr_transformer.push(None),
}
}
// build plan transformer
let transformer = Arc::new(move |plan: &LogicalPlan| {
if let LogicalPlan::Aggregate(aggr) = plan {
let mut new_plan = aggr.clone();
// transform aggr exprs
for (expr, transformer) in
new_plan.aggr_expr.iter_mut().zip(&expr_transformer)
{
if let Some(transformer) = transformer {
let new_expr = transformer(expr)?;
*expr = new_expr;
}
}
// transform group exprs
for expr in new_plan.group_expr.iter_mut() {
// if let Some(transformer) = transformer {
// let new_expr = transformer(expr)?;
// *expr = new_expr;
// }
let expr_name = expr.name_for_alias().expect("not a sort expr");
*expr = Expr::Column(expr_name.into());
}
common_telemetry::info!(
"[DEBUG] new plan aggr expr: {:?}, group expr: {:?}",
new_plan.aggr_expr,
new_plan.group_expr
);
Some(LogicalPlan::Aggregate(new_plan))
} else {
None
}
});
common_telemetry::info!("[DEBUG] done TransformedCommutative for aggr plan ");
Commutativity::TransformedCommutative(Some(transformer))
} }
LogicalPlan::Sort(_) => { LogicalPlan::Sort(_) => {
if partition_cols.is_empty() { if partition_cols.is_empty() {
@@ -113,7 +269,7 @@ impl Categorizer {
} }
} }
pub fn check_extension_plan(plan: &dyn UserDefinedLogicalNode) -> Commutativity { pub fn check_extension_plan(plan: &dyn UserDefinedLogicalNode) -> Commutativity<LogicalPlan> {
match plan.name() { match plan.name() {
name if name == EmptyMetric::name() name if name == EmptyMetric::name()
|| name == InstantManipulate::name() || name == InstantManipulate::name()
@@ -129,7 +285,7 @@ impl Categorizer {
} }
} }
pub fn check_expr(expr: &Expr) -> Commutativity { pub fn check_expr(expr: &Expr) -> Commutativity<Expr> {
match expr { match expr {
Expr::Column(_) Expr::Column(_)
| Expr::ScalarVariable(_, _) | Expr::ScalarVariable(_, _)
@@ -155,13 +311,14 @@ impl Categorizer {
| Expr::Case(_) | Expr::Case(_)
| Expr::Cast(_) | Expr::Cast(_)
| Expr::TryCast(_) | Expr::TryCast(_)
| Expr::AggregateFunction(_)
| Expr::WindowFunction(_) | Expr::WindowFunction(_)
| Expr::InList(_) | Expr::InList(_)
| Expr::InSubquery(_) | Expr::InSubquery(_)
| Expr::ScalarSubquery(_) | Expr::ScalarSubquery(_)
| Expr::Wildcard { .. } => Commutativity::Unimplemented, | Expr::Wildcard { .. } => Commutativity::Unimplemented,
Expr::AggregateFunction(aggr_fn) => Self::check_aggregate_fn(aggr_fn),
Expr::Alias(_) Expr::Alias(_)
| Expr::Unnest(_) | Expr::Unnest(_)
| Expr::GroupingSet(_) | Expr::GroupingSet(_)
@@ -170,6 +327,59 @@ impl Categorizer {
} }
} }
/// Decide the commutativity level of a single aggregate function call.
///
/// Built-in `min`, `max` and `count` are reported as
/// [`Commutativity::ConditionalCommutative`] together with a transformer that
/// builds the merging expression:
/// - `min` / `max`: the same function applied to the partial results;
/// - `count`: `sum` applied to the partial counts.
///
/// The merging expression reads the column named after the original aggregate
/// expression and is aliased to that same name. The transformer ignores its
/// input expression.
///
/// Every other function, including all UDFs, is [`Commutativity::Unimplemented`].
/// The same holds for calls that carry `DISTINCT`, `FILTER` or `ORDER BY`: the
/// merging function is derived from a clone of the original call and would
/// inherit those modifiers, e.g. turning `count(DISTINCT x)` into a
/// `sum(DISTINCT ..)` over partial counts, which does not yield the distinct count.
fn check_aggregate_fn(aggr_fn: &AggregateFunction) -> Commutativity<Expr> {
    common_telemetry::info!("[DEBUG] checking aggr_fn: {:?}", aggr_fn);

    let func_def = match &aggr_fn.func_def {
        AggregateFunctionDefinition::BuiltIn(func_def) => func_def,
        AggregateFunctionDefinition::UDF(_) => return Commutativity::Unimplemented,
    };
    if !matches!(
        func_def,
        BuiltInAggregateFunction::Max
            | BuiltInAggregateFunction::Min
            | BuiltInAggregateFunction::Count
    ) {
        return Commutativity::Unimplemented;
    }
    // Modifiers would be copied onto the merging function, where they are either
    // wrong (DISTINCT) or refer to inputs the partial result no longer exposes.
    if aggr_fn.distinct || aggr_fn.filter.is_some() || aggr_fn.order_by.is_some() {
        return Commutativity::Unimplemented;
    }

    // The partial result is exposed under the display name of the original call.
    let output_name = Expr::AggregateFunction(aggr_fn.clone())
        .name_for_alias()
        .expect("not a sort expr");
    let mut merge_fn = aggr_fn.clone();
    merge_fn.args = vec![Expr::Column(output_name.clone().into())];

    if matches!(func_def, BuiltInAggregateFunction::Count) {
        common_telemetry::info!("[DEBUG] checking count_fn: {:?}", aggr_fn);
        // Partial counts are merged by summing them up.
        let sum_udf = Arc::new(AggregateUDF::new_from_impl(Sum::new()));
        merge_fn.func_def = AggregateFunctionDefinition::UDF(sum_udf);
        Commutativity::ConditionalCommutative(Some(Arc::new(move |_| {
            common_telemetry::info!("[DEBUG] transforming sum_fn: {:?}", merge_fn);
            Some(Expr::AggregateFunction(merge_fn.clone()).alias(output_name.clone()))
        })))
    } else {
        common_telemetry::info!("[DEBUG] checking min/max: {:?}", aggr_fn);
        // min of mins / max of maxes: the function itself is reused.
        Commutativity::ConditionalCommutative(Some(Arc::new(move |_| {
            common_telemetry::info!("[DEBUG] transforming min/max fn: {:?}", merge_fn);
            Some(Expr::AggregateFunction(merge_fn.clone()).alias(output_name.clone()))
        })))
    }
}
/// Return true if the given expr and partition cols satisfied the rule. /// Return true if the given expr and partition cols satisfied the rule.
/// In this case the plan can be treated as fully commutative. /// In this case the plan can be treated as fully commutative.
fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool { fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool {
@@ -191,12 +401,16 @@ impl Categorizer {
} }
} }
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>; pub type Transformer<T> = Arc<dyn for<'a> Fn(&'a T) -> Option<T>>;
pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> { pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
Some(plan.clone()) Some(plan.clone())
} }
/// Expression transformer that leaves the expression as it is.
///
/// Always returns `Some` holding a clone of `expr`.
pub fn expr_partial_commutative_transformer(expr: &Expr) -> Option<Expr> {
    let unchanged = expr.clone();
    Some(unchanged)
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datafusion_expr::{LogicalPlanBuilder, Sort}; use datafusion_expr::{LogicalPlanBuilder, Sort};

View File

@@ -19,6 +19,7 @@ use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::Result as DataFusionResult; use datafusion_common::Result as DataFusionResult;
@@ -67,10 +68,12 @@ impl WindowedSortPhysicalRule {
.transform_down(|plan| { .transform_down(|plan| {
if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() { if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
// TODO: support multiple expr in windowed sort // TODO: support multiple expr in windowed sort
if !sort_exec.preserve_partitioning() || sort_exec.expr().len() != 1 { if sort_exec.expr().len() != 1 {
return Ok(Transformed::no(plan)); return Ok(Transformed::no(plan));
} }
let preserve_partitioning = sort_exec.preserve_partitioning();
let Some(scanner_info) = fetch_partition_range(sort_exec.input().clone())? let Some(scanner_info) = fetch_partition_range(sort_exec.input().clone())?
else { else {
return Ok(Transformed::no(plan)); return Ok(Transformed::no(plan));
@@ -111,11 +114,23 @@ impl WindowedSortPhysicalRule {
new_input, new_input,
)?; )?;
return Ok(Transformed { if !preserve_partitioning {
data: Arc::new(windowed_sort_exec), let order_preserving_merge = SortPreservingMergeExec::new(
transformed: true, sort_exec.expr().to_vec(),
tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop, Arc::new(windowed_sort_exec),
}); );
return Ok(Transformed {
data: Arc::new(order_preserving_merge),
transformed: true,
tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
});
} else {
return Ok(Transformed {
data: Arc::new(windowed_sort_exec),
transformed: true,
tnr: datafusion_common::tree_node::TreeNodeRecursion::Stop,
});
}
} }
Ok(Transformed::no(plan)) Ok(Transformed::no(plan))
@@ -126,6 +141,7 @@ impl WindowedSortPhysicalRule {
} }
} }
#[derive(Debug)]
struct ScannerInfo { struct ScannerInfo {
partition_ranges: Vec<Vec<PartitionRange>>, partition_ranges: Vec<Vec<PartitionRange>>,
time_index: String, time_index: String,
@@ -136,11 +152,11 @@ fn fetch_partition_range(input: Arc<dyn ExecutionPlan>) -> DataFusionResult<Opti
let mut partition_ranges = None; let mut partition_ranges = None;
let mut time_index = None; let mut time_index = None;
let mut tag_columns = None; let mut tag_columns = None;
let mut is_batch_coalesced = false;
input.transform_up(|plan| { input.transform_up(|plan| {
// Unappliable case, reset the state. // Unappliable case, reset the state.
if plan.as_any().is::<RepartitionExec>() if plan.as_any().is::<RepartitionExec>()
|| plan.as_any().is::<CoalesceBatchesExec>()
|| plan.as_any().is::<CoalescePartitionsExec>() || plan.as_any().is::<CoalescePartitionsExec>()
|| plan.as_any().is::<SortExec>() || plan.as_any().is::<SortExec>()
|| plan.as_any().is::<WindowedSortExec>() || plan.as_any().is::<WindowedSortExec>()
@@ -148,13 +164,19 @@ fn fetch_partition_range(input: Arc<dyn ExecutionPlan>) -> DataFusionResult<Opti
partition_ranges = None; partition_ranges = None;
} }
if plan.as_any().is::<CoalesceBatchesExec>() {
is_batch_coalesced = true;
}
if let Some(region_scan_exec) = plan.as_any().downcast_ref::<RegionScanExec>() { if let Some(region_scan_exec) = plan.as_any().downcast_ref::<RegionScanExec>() {
partition_ranges = Some(region_scan_exec.get_uncollapsed_partition_ranges()); partition_ranges = Some(region_scan_exec.get_uncollapsed_partition_ranges());
time_index = Some(region_scan_exec.time_index()); time_index = Some(region_scan_exec.time_index());
tag_columns = Some(region_scan_exec.tag_columns()); tag_columns = Some(region_scan_exec.tag_columns());
// set distinguish_partition_ranges to true, this is an incorrect workaround // set distinguish_partition_ranges to true, this is an incorrect workaround
region_scan_exec.with_distinguish_partition_range(true); if !is_batch_coalesced {
region_scan_exec.with_distinguish_partition_range(true);
}
} }
Ok(Transformed::no(plan)) Ok(Transformed::no(plan))

View File

@@ -12,6 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! Module for sorting input data within each [`PartitionRange`].
//!
//! This module defines the [`PartSortExec`] execution plan, which sorts each
//! partition ([`PartitionRange`]) independently based on the provided physical
//! sort expressions.
use std::any::Any; use std::any::Any;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
@@ -36,7 +42,7 @@ use itertools::Itertools;
use snafu::location; use snafu::location;
use store_api::region_engine::PartitionRange; use store_api::region_engine::PartitionRange;
use crate::downcast_ts_array; use crate::{array_iter_helper, downcast_ts_array};
/// Sort input within given PartitionRange /// Sort input within given PartitionRange
/// ///
@@ -288,9 +294,51 @@ impl PartSortStream {
Ok(()) Ok(())
} }
/// Locate the first row of `sort_column` that falls outside the current partition range.
///
/// The current range is `self.partition_ranges[self.cur_part_idx]`; a value lies
/// inside it when `start <= value < end`. Null values are skipped.
///
/// # Returns
/// - `Ok(Some(0))` if `sort_column` is empty.
/// - `Ok(Some(idx))` where `idx` is the position of the first non-null value
///   outside the current range.
/// - `Ok(None)` if every non-null value is inside the current range.
///
/// # Errors
/// Returns an internal error if `self.cur_part_idx` is not a valid index into
/// `self.partition_ranges`, or if the data type of `sort_column` is not one
/// handled by `downcast_ts_array`.
fn try_find_next_range(
    &self,
    sort_column: &ArrayRef,
) -> datafusion_common::Result<Option<usize>> {
    if sort_column.is_empty() {
        return Ok(Some(0));
    }

    // check if the current partition index is out of range
    let Some(cur_range) = self.partition_ranges.get(self.cur_part_idx) else {
        return internal_err!(
            "Partition index out of range: {} >= {}",
            self.cur_part_idx,
            self.partition_ranges.len()
        );
    };
    let lower = cur_range.start.value();
    let upper = cur_range.end.value();

    let sort_column_iter = downcast_ts_array!(
        sort_column.data_type() => (array_iter_helper, sort_column),
        _ => internal_err!(
            "Unsupported data type for sort column: {:?}",
            sort_column.data_type()
        )?,
    );

    // ignore vacant time index data, then pick the first value outside `[lower, upper)`
    let first_outside = sort_column_iter
        .filter_map(|(idx, val)| val.map(|val| (idx, val)))
        .find(|&(_, val)| val >= upper || val < lower)
        .map(|(idx, _)| idx);

    Ok(first_outside)
}
/// Sort and clear the buffer and return the sorted record batch /// Sort and clear the buffer and return the sorted record batch
/// ///
/// this function should return a empty record batch if the buffer is empty /// this function will return a empty record batch if the buffer is empty
fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> { fn sort_buffer(&mut self) -> datafusion_common::Result<DfRecordBatch> {
if self.buffer.is_empty() { if self.buffer.is_empty() {
return Ok(DfRecordBatch::new_empty(self.schema.clone())); return Ok(DfRecordBatch::new_empty(self.schema.clone()));
@@ -317,6 +365,9 @@ impl PartSortStream {
Some(format!("Fail to sort to indices at {}", location!())), Some(format!("Fail to sort to indices at {}", location!())),
) )
})?; })?;
if indices.is_empty() {
return Ok(DfRecordBatch::new_empty(self.schema.clone()));
}
self.check_in_range( self.check_in_range(
&sort_column, &sort_column,
@@ -379,6 +430,7 @@ impl PartSortStream {
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> { ) -> Poll<Option<datafusion_common::Result<DfRecordBatch>>> {
loop { loop {
// no more input, sort the buffer and return
if self.input_complete { if self.input_complete {
if self.buffer.is_empty() { if self.buffer.is_empty() {
return Poll::Ready(None); return Poll::Ready(None);
@@ -386,19 +438,47 @@ impl PartSortStream {
return Poll::Ready(Some(self.sort_buffer())); return Poll::Ready(Some(self.sort_buffer()));
} }
} }
// fetch next batch from input
let res = self.input.as_mut().poll_next(cx); let res = self.input.as_mut().poll_next(cx);
match res { match res {
Poll::Ready(Some(Ok(batch))) => { Poll::Ready(Some(Ok(batch))) => {
if batch.num_rows() == 0 { let sort_column = self
.expression
.expr
.evaluate(&batch)?
.into_array(batch.num_rows())?;
let next_range_idx = self.try_find_next_range(&sort_column)?;
// `Some` means the current range is finished, split the batch into two parts and sort
if let Some(idx) = next_range_idx {
let this_range = batch.slice(0, idx);
let next_range = batch.slice(idx, batch.num_rows() - idx);
if this_range.num_rows() != 0 {
self.buffer.push(this_range);
}
// mark end of current PartitionRange // mark end of current PartitionRange
let sorted_batch = self.sort_buffer()?; let sorted_batch = self.sort_buffer()?;
self.cur_part_idx += 1; let next_sort_column = sort_column.slice(idx, batch.num_rows() - idx);
// step to next proper PartitionRange
loop {
self.cur_part_idx += 1;
if next_sort_column.is_empty()
|| self.try_find_next_range(&next_sort_column)?.is_none()
{
break;
}
}
// push the next range to the buffer
if next_range.num_rows() != 0 {
self.buffer.push(next_range);
}
if sorted_batch.num_rows() == 0 { if sorted_batch.num_rows() == 0 {
// Current part is empty, continue polling next part. // Current part is empty, continue polling next part.
continue; continue;
} }
return Poll::Ready(Some(Ok(sorted_batch))); return Poll::Ready(Some(Ok(sorted_batch)));
} }
self.buffer.push(batch); self.buffer.push(batch);
// keep polling until boundary(a empty RecordBatch) is reached // keep polling until boundary(a empty RecordBatch) is reached
continue; continue;

View File

@@ -21,7 +21,7 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use arrow::array::{Array, ArrayRef, PrimitiveArray}; use arrow::array::{Array, ArrayRef};
use arrow::compute::SortColumn; use arrow::compute::SortColumn;
use arrow_schema::{DataType, SchemaRef, SortOptions}; use arrow_schema::{DataType, SchemaRef, SortOptions};
use common_error::ext::{BoxedError, PlainError}; use common_error::ext::{BoxedError, PlainError};
@@ -812,9 +812,16 @@ fn find_slice_from_range(
Ok((start, end - start)) Ok((start, end - start))
} }
/// Get an iterator from a primitive array.
///
/// Used with `downcast_ts_array`. The returned iter is wrapped with `.enumerate()`.
#[macro_export]
macro_rules! array_iter_helper { macro_rules! array_iter_helper {
($t:ty, $unit:expr, $arr:expr) => {{ ($t:ty, $unit:expr, $arr:expr) => {{
let typed = $arr.as_any().downcast_ref::<PrimitiveArray<$t>>().unwrap(); let typed = $arr
.as_any()
.downcast_ref::<arrow::array::PrimitiveArray<$t>>()
.unwrap();
let iter = typed.iter().enumerate(); let iter = typed.iter().enumerate();
Box::new(iter) as Box<dyn Iterator<Item = (usize, Option<i64>)>> Box::new(iter) as Box<dyn Iterator<Item = (usize, Option<i64>)>>
}}; }};