diff --git a/src/query/src/dist_plan/analyzer.rs b/src/query/src/dist_plan/analyzer.rs index bbb3e5ddd9..a81e62db1c 100644 --- a/src/query/src/dist_plan/analyzer.rs +++ b/src/query/src/dist_plan/analyzer.rs @@ -158,7 +158,10 @@ impl PlanRewriter { return true; } - match Categorizer::check_plan(plan, self.partition_cols.clone()) { + let result = Categorizer::check_plan(plan, self.partition_cols.clone()); + common_telemetry::info!("[DEBUG] Categorizer result: {:?}", result); + + match result { Commutativity::Commutative => {} Commutativity::PartialCommutative => { if let Some(plan) = partial_commutative_transformer(plan) { @@ -169,6 +172,7 @@ impl PlanRewriter { if let Some(transformer) = transformer && let Some(plan) = transformer(plan) { + common_telemetry::info!("ConditionalCommutative new plan: {:?}", plan); self.stage.push(plan) } } @@ -176,6 +180,7 @@ impl PlanRewriter { if let Some(transformer) = transformer && let Some(plan) = transformer(plan) { + common_telemetry::info!("TransformedCommutative new plan: {:?}", plan); self.stage.push(plan) } } @@ -277,8 +282,9 @@ impl TreeNodeRewriter for PlanRewriter { // add merge scan as the new root let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan(); // expand stages + common_telemetry::info!("[DEBUG] end with stage, expanding: {:?}", self.stage); for new_stage in self.stage.drain(..) { - node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])? + node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])? } self.set_expanded(); @@ -288,12 +294,14 @@ impl TreeNodeRewriter for PlanRewriter { // TODO(ruihang): avoid this clone if self.should_expand(&parent.clone()) { - // TODO(ruihang): does this work for nodes with multiple children?; + common_telemetry::info!("[DEBUG] expand with stage: {:?}", self.stage); + // TODO(ruihang): does this work for nodes with multiple children? // replace the current node with expanded one let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan(); // expand stages for new_stage in self.stage.drain(..) { - node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])? + node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])? + // node = new_stage.with_new } self.set_expanded(); diff --git a/src/query/src/dist_plan/commutativity.rs b/src/query/src/dist_plan/commutativity.rs index 6da9d4bf92..4649980391 100644 --- a/src/query/src/dist_plan/commutativity.rs +++ b/src/query/src/dist_plan/commutativity.rs @@ -15,8 +15,11 @@ use std::collections::HashSet; use std::sync::Arc; +use datafusion::functions_aggregate::sum::Sum; +use datafusion_expr::aggregate_function::AggregateFunction as BuiltInAggregateFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition}; use datafusion_expr::utils::exprlist_to_columns; -use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode}; +use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, UserDefinedLogicalNode}; use promql::extension_plan::{ EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize, }; @@ -24,21 +27,91 @@ use promql::extension_plan::{ use crate::dist_plan::MergeScanLogicalPlan; #[allow(dead_code)] -pub enum Commutativity { +pub enum Commutativity { Commutative, PartialCommutative, - ConditionalCommutative(Option), - TransformedCommutative(Option), + ConditionalCommutative(Option>), + TransformedCommutative(Option>), NonCommutative, Unimplemented, /// For unrelated plans like DDL Unsupported, } +impl Commutativity { + /// Check if self is stricter than `lhs` + fn is_stricter_than(&self, lhs: &Self) -> bool { + match (lhs, self) { + (Commutativity::Commutative, Commutativity::Commutative) => false, + (Commutativity::Commutative, _) => true, + + ( + Commutativity::PartialCommutative, + Commutativity::Commutative | Commutativity::PartialCommutative, + ) => false, + (Commutativity::PartialCommutative, _) => true, + + ( + Commutativity::ConditionalCommutative(_), + Commutativity::Commutative + | Commutativity::PartialCommutative + | Commutativity::ConditionalCommutative(_), + ) => false, + (Commutativity::ConditionalCommutative(_), _) => true, + + ( + Commutativity::TransformedCommutative(_), + Commutativity::Commutative + | Commutativity::PartialCommutative + | Commutativity::ConditionalCommutative(_) + | Commutativity::TransformedCommutative(_), + ) => false, + (Commutativity::TransformedCommutative(_), _) => true, + + ( + Commutativity::NonCommutative + | Commutativity::Unimplemented + | Commutativity::Unsupported, + _, + ) => false, + } + } + + /// Return a bare commutative level without any transformer + fn bare_level(&self) -> Commutativity { + match self { + Commutativity::Commutative => Commutativity::Commutative, + Commutativity::PartialCommutative => Commutativity::PartialCommutative, + Commutativity::ConditionalCommutative(_) => Commutativity::ConditionalCommutative(None), + Commutativity::TransformedCommutative(_) => Commutativity::TransformedCommutative(None), + Commutativity::NonCommutative => Commutativity::NonCommutative, + Commutativity::Unimplemented => Commutativity::Unimplemented, + Commutativity::Unsupported => Commutativity::Unsupported, + } + } +} + +impl std::fmt::Debug for Commutativity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Commutativity::Commutative => write!(f, "Commutative"), + Commutativity::PartialCommutative => write!(f, "PartialCommutative"), + Commutativity::ConditionalCommutative(_) => write!(f, "ConditionalCommutative"), + Commutativity::TransformedCommutative(_) => write!(f, "TransformedCommutative"), + Commutativity::NonCommutative => write!(f, "NonCommutative"), + Commutativity::Unimplemented => write!(f, "Unimplemented"), + Commutativity::Unsupported => write!(f, "Unsupported"), + } + } +} + pub struct Categorizer {} impl Categorizer { - pub fn check_plan(plan: &LogicalPlan, partition_cols: Option>) -> Commutativity { + pub fn check_plan( + plan: &LogicalPlan, + partition_cols: Option>, + ) -> Commutativity { let partition_cols = partition_cols.unwrap_or_default(); match plan { @@ -46,21 +119,91 @@ impl Categorizer { for expr in &proj.expr { let commutativity = Self::check_expr(expr); if !matches!(commutativity, Commutativity::Commutative) { - return commutativity; + return commutativity.bare_level(); } } Commutativity::Commutative } // TODO(ruihang): Change this to Commutative once Like is supported in substrait - LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate), + LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate).bare_level(), LogicalPlan::Window(_) => Commutativity::Unimplemented, LogicalPlan::Aggregate(aggr) => { + // fast path: if the group_expr is a subset of partition_cols if Self::check_partition(&aggr.group_expr, &partition_cols) { return Commutativity::Commutative; } - // check all children exprs and uses the strictest level - Commutativity::Unimplemented + common_telemetry::info!("[DEBUG] aggregate plan expr: {:?}", aggr.aggr_expr); + + // get all commutativity levels of aggregate exprs and find the strictest one + let aggr_expr_comm = aggr + .aggr_expr + .iter() + .map(Self::check_expr) + .collect::>(); + let mut strictest = Commutativity::Commutative; + for comm in &aggr_expr_comm { + if comm.is_stricter_than(&strictest) { + strictest = comm.bare_level(); + } + } + + common_telemetry::info!("[DEBUG] aggr_expr_comm: {:?}", aggr_expr_comm); + common_telemetry::info!("[DEBUG] strictest: {:?}", strictest); + + // fast path: if any expr is commutative or non-commutative + if matches!( + strictest, + Commutativity::Commutative + | Commutativity::NonCommutative + | Commutativity::Unimplemented + | Commutativity::Unsupported + ) { + return strictest.bare_level(); + } + + common_telemetry::info!("[DEBUG] continue for strictest",); + + // collect expr transformers + let mut expr_transformer = Vec::with_capacity(aggr.aggr_expr.len()); + for expr_comm in aggr_expr_comm { + match expr_comm { + Commutativity::Commutative => expr_transformer.push(None), + Commutativity::ConditionalCommutative(transformer) => { + expr_transformer.push(transformer.clone()); + } + Commutativity::PartialCommutative => expr_transformer + .push(Some(Arc::new(expr_partial_commutative_transformer))), + _ => expr_transformer.push(None), + } + } + + // build plan transformer + let transformer = Arc::new(move |plan: &LogicalPlan| { + if let LogicalPlan::Aggregate(aggr) = plan { + let mut new_plan = aggr.clone(); + for (expr, transformer) in + new_plan.aggr_expr.iter_mut().zip(&expr_transformer) + { + if let Some(transformer) = transformer { + let new_expr = transformer(expr)?; + *expr = new_expr; + } + } + common_telemetry::info!( + "[DEBUG] new plan aggr expr: {:?}, group expr: {:?}", + new_plan.aggr_expr, + new_plan.group_expr + ); + Some(LogicalPlan::Aggregate(new_plan)) + } else { + None + } + }); + + common_telemetry::info!("[DEBUG] done TransformedCommutative for aggr plan "); + + Commutativity::TransformedCommutative(Some(transformer)) } LogicalPlan::Sort(_) => { if partition_cols.is_empty() { @@ -111,7 +254,7 @@ impl Categorizer { } } - pub fn check_extension_plan(plan: &dyn UserDefinedLogicalNode) -> Commutativity { + pub fn check_extension_plan(plan: &dyn UserDefinedLogicalNode) -> Commutativity { match plan.name() { name if name == EmptyMetric::name() || name == InstantManipulate::name() @@ -126,7 +269,7 @@ impl Categorizer { } } - pub fn check_expr(expr: &Expr) -> Commutativity { + pub fn check_expr(expr: &Expr) -> Commutativity { match expr { Expr::Column(_) | Expr::ScalarVariable(_, _) @@ -152,13 +295,14 @@ impl Categorizer { | Expr::Case(_) | Expr::Cast(_) | Expr::TryCast(_) - | Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::InList(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } => Commutativity::Unimplemented, + Expr::AggregateFunction(aggr_fn) => Self::check_aggregate_fn(aggr_fn), + Expr::Alias(_) | Expr::Unnest(_) | Expr::GroupingSet(_) @@ -167,6 +311,59 @@ impl Categorizer { } } + fn check_aggregate_fn(aggr_fn: &AggregateFunction) -> Commutativity { + common_telemetry::info!("[DEBUG] checking aggr_fn: {:?}", aggr_fn); + match &aggr_fn.func_def { + AggregateFunctionDefinition::BuiltIn(func_def) => match func_def { + BuiltInAggregateFunction::Max | BuiltInAggregateFunction::Min => { + // Commutativity::PartialCommutative + common_telemetry::info!("[DEBUG] checking min/max: {:?}", aggr_fn); + let mut new_fn = aggr_fn.clone(); + let col_name = Expr::AggregateFunction(aggr_fn.clone()) + .name_for_alias() + .expect("not a sort expr"); + let alias = col_name.clone(); + new_fn.args = vec![Expr::Column(col_name.into())]; + + // new_fn.func_def = + // AggregateFunctionDefinition::BuiltIn(BuiltInAggregateFunction::Sum); + Commutativity::ConditionalCommutative(Some(Arc::new(move |_| { + common_telemetry::info!("[DEBUG] transforming min/max fn: {:?}", new_fn); + Some(Expr::AggregateFunction(new_fn.clone()).alias(alias.clone())) + }))) + } + BuiltInAggregateFunction::Count => { + common_telemetry::info!("[DEBUG] checking count_fn: {:?}", aggr_fn); + let col_name = Expr::AggregateFunction(aggr_fn.clone()) + .name_for_alias() + .expect("not a sort expr"); + let sum_udf = Arc::new(AggregateUDF::new_from_impl(Sum::new())); + let alias = col_name.clone(); + // let sum_func = Arc::new(AggregateFunction::new_udf( + // sum_udf, + // vec![Expr::Column(col_name.into())], + // false, + // None, + // None, + // None, + // )); + let mut sum_expr = aggr_fn.clone(); + sum_expr.func_def = AggregateFunctionDefinition::UDF(sum_udf); + sum_expr.args = vec![Expr::Column(col_name.into())]; + // let mut sum_fn = aggr_fn.clone(); + // sum_fn.func_def = + // AggregateFunctionDefinition::BuiltIn(BuiltInAggregateFunction::Sum); + Commutativity::ConditionalCommutative(Some(Arc::new(move |_| { + common_telemetry::info!("[DEBUG] transforming sum_fn: {:?}", sum_expr); + Some(Expr::AggregateFunction(sum_expr.clone()).alias(alias.clone())) + }))) + } + _ => Commutativity::Unimplemented, + }, + AggregateFunctionDefinition::UDF(_) => Commutativity::Unimplemented, + } + } + /// Return true if the given expr and partition cols satisfied the rule. /// In this case the plan can be treated as fully commutative. fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool { @@ -188,12 +385,16 @@ impl Categorizer { } } -pub type Transformer = Arc Option>; +pub type Transformer = Arc Fn(&'a T) -> Option>; pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option { Some(plan.clone()) } +pub fn expr_partial_commutative_transformer(expr: &Expr) -> Option { + Some(expr.clone()) +} + #[cfg(test)] mod test { use datafusion_expr::{LogicalPlanBuilder, Sort};