mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-05 21:02:58 +00:00
feat: support transforming min/max/count aggr fn
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
@@ -158,7 +158,10 @@ impl PlanRewriter {
|
||||
return true;
|
||||
}
|
||||
|
||||
match Categorizer::check_plan(plan, self.partition_cols.clone()) {
|
||||
let result = Categorizer::check_plan(plan, self.partition_cols.clone());
|
||||
common_telemetry::info!("[DEBUG] Categorizer result: {:?}", result);
|
||||
|
||||
match result {
|
||||
Commutativity::Commutative => {}
|
||||
Commutativity::PartialCommutative => {
|
||||
if let Some(plan) = partial_commutative_transformer(plan) {
|
||||
@@ -169,6 +172,7 @@ impl PlanRewriter {
|
||||
if let Some(transformer) = transformer
|
||||
&& let Some(plan) = transformer(plan)
|
||||
{
|
||||
common_telemetry::info!("ConditionalCommutative new plan: {:?}", plan);
|
||||
self.stage.push(plan)
|
||||
}
|
||||
}
|
||||
@@ -176,6 +180,7 @@ impl PlanRewriter {
|
||||
if let Some(transformer) = transformer
|
||||
&& let Some(plan) = transformer(plan)
|
||||
{
|
||||
common_telemetry::info!("TransformedCommutative new plan: {:?}", plan);
|
||||
self.stage.push(plan)
|
||||
}
|
||||
}
|
||||
@@ -277,8 +282,9 @@ impl TreeNodeRewriter for PlanRewriter {
|
||||
// add merge scan as the new root
|
||||
let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan();
|
||||
// expand stages
|
||||
common_telemetry::info!("[DEBUG] end with stage, expanding: {:?}", self.stage);
|
||||
for new_stage in self.stage.drain(..) {
|
||||
node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])?
|
||||
node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])?
|
||||
}
|
||||
self.set_expanded();
|
||||
|
||||
@@ -288,12 +294,14 @@ impl TreeNodeRewriter for PlanRewriter {
|
||||
|
||||
// TODO(ruihang): avoid this clone
|
||||
if self.should_expand(&parent.clone()) {
|
||||
// TODO(ruihang): does this work for nodes with multiple children?;
|
||||
common_telemetry::info!("[DEBUG] expand with stage: {:?}", self.stage);
|
||||
// TODO(ruihang): does this work for nodes with multiple children?
|
||||
// replace the current node with expanded one
|
||||
let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan();
|
||||
// expand stages
|
||||
for new_stage in self.stage.drain(..) {
|
||||
node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])?
|
||||
node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])?
|
||||
// node = new_stage.with_new
|
||||
}
|
||||
self.set_expanded();
|
||||
|
||||
|
||||
@@ -15,8 +15,11 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion::functions_aggregate::sum::Sum;
|
||||
use datafusion_expr::aggregate_function::AggregateFunction as BuiltInAggregateFunction;
|
||||
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition};
|
||||
use datafusion_expr::utils::exprlist_to_columns;
|
||||
use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
|
||||
use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, UserDefinedLogicalNode};
|
||||
use promql::extension_plan::{
|
||||
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
|
||||
};
|
||||
@@ -24,21 +27,91 @@ use promql::extension_plan::{
|
||||
use crate::dist_plan::MergeScanLogicalPlan;
|
||||
|
||||
/// Commutativity level assigned to a plan or expression by `Categorizer`.
///
/// `T` is the kind of node being classified: `Categorizer::check_plan` returns
/// `Commutativity<LogicalPlan>` and `Categorizer::check_expr` returns
/// `Commutativity<Expr>`.
// NOTE(review): this span shows both the pre-change (non-generic) and the post-change
// (generic) lines of the diff; the generic form is the one the impl blocks below use.
#[allow(dead_code)]
|
||||
pub enum Commutativity {
|
||||
pub enum Commutativity<T> {
|
||||
Commutative,
|
||||
PartialCommutative,
|
||||
ConditionalCommutative(Option<Transformer>),
|
||||
TransformedCommutative(Option<Transformer>),
|
||||
/// Carries an optional transformer that produces a replacement `T`.
ConditionalCommutative(Option<Transformer<T>>),
|
||||
/// Carries an optional transformer that produces a replacement `T`.
TransformedCommutative(Option<Transformer<T>>),
|
||||
NonCommutative,
|
||||
Unimplemented,
|
||||
/// For unrelated plans like DDL
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
impl<T> Commutativity<T> {
|
||||
/// Check if self is stricter than `lhs`
|
||||
fn is_stricter_than(&self, lhs: &Self) -> bool {
|
||||
match (lhs, self) {
|
||||
(Commutativity::Commutative, Commutativity::Commutative) => false,
|
||||
(Commutativity::Commutative, _) => true,
|
||||
|
||||
(
|
||||
Commutativity::PartialCommutative,
|
||||
Commutativity::Commutative | Commutativity::PartialCommutative,
|
||||
) => false,
|
||||
(Commutativity::PartialCommutative, _) => true,
|
||||
|
||||
(
|
||||
Commutativity::ConditionalCommutative(_),
|
||||
Commutativity::Commutative
|
||||
| Commutativity::PartialCommutative
|
||||
| Commutativity::ConditionalCommutative(_),
|
||||
) => false,
|
||||
(Commutativity::ConditionalCommutative(_), _) => true,
|
||||
|
||||
(
|
||||
Commutativity::TransformedCommutative(_),
|
||||
Commutativity::Commutative
|
||||
| Commutativity::PartialCommutative
|
||||
| Commutativity::ConditionalCommutative(_)
|
||||
| Commutativity::TransformedCommutative(_),
|
||||
) => false,
|
||||
(Commutativity::TransformedCommutative(_), _) => true,
|
||||
|
||||
(
|
||||
Commutativity::NonCommutative
|
||||
| Commutativity::Unimplemented
|
||||
| Commutativity::Unsupported,
|
||||
_,
|
||||
) => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a bare commutative level without any transformer
|
||||
fn bare_level<To>(&self) -> Commutativity<To> {
|
||||
match self {
|
||||
Commutativity::Commutative => Commutativity::Commutative,
|
||||
Commutativity::PartialCommutative => Commutativity::PartialCommutative,
|
||||
Commutativity::ConditionalCommutative(_) => Commutativity::ConditionalCommutative(None),
|
||||
Commutativity::TransformedCommutative(_) => Commutativity::TransformedCommutative(None),
|
||||
Commutativity::NonCommutative => Commutativity::NonCommutative,
|
||||
Commutativity::Unimplemented => Commutativity::Unimplemented,
|
||||
Commutativity::Unsupported => Commutativity::Unsupported,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::fmt::Debug for Commutativity<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Commutativity::Commutative => write!(f, "Commutative"),
|
||||
Commutativity::PartialCommutative => write!(f, "PartialCommutative"),
|
||||
Commutativity::ConditionalCommutative(_) => write!(f, "ConditionalCommutative"),
|
||||
Commutativity::TransformedCommutative(_) => write!(f, "TransformedCommutative"),
|
||||
Commutativity::NonCommutative => write!(f, "NonCommutative"),
|
||||
Commutativity::Unimplemented => write!(f, "Unimplemented"),
|
||||
Commutativity::Unsupported => write!(f, "Unsupported"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Field-less type whose associated functions (`check_plan`, `check_extension_plan`,
/// `check_expr`, ...) classify logical plans and expressions into a `Commutativity` level.
pub struct Categorizer {}
|
||||
|
||||
impl Categorizer {
|
||||
pub fn check_plan(plan: &LogicalPlan, partition_cols: Option<Vec<String>>) -> Commutativity {
|
||||
pub fn check_plan(
|
||||
plan: &LogicalPlan,
|
||||
partition_cols: Option<Vec<String>>,
|
||||
) -> Commutativity<LogicalPlan> {
|
||||
let partition_cols = partition_cols.unwrap_or_default();
|
||||
|
||||
match plan {
|
||||
@@ -46,21 +119,91 @@ impl Categorizer {
|
||||
for expr in &proj.expr {
|
||||
let commutativity = Self::check_expr(expr);
|
||||
if !matches!(commutativity, Commutativity::Commutative) {
|
||||
return commutativity;
|
||||
return commutativity.bare_level();
|
||||
}
|
||||
}
|
||||
Commutativity::Commutative
|
||||
}
|
||||
// TODO(ruihang): Change this to Commutative once Like is supported in substrait
|
||||
LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate),
|
||||
LogicalPlan::Filter(filter) => Self::check_expr(&filter.predicate).bare_level(),
|
||||
LogicalPlan::Window(_) => Commutativity::Unimplemented,
|
||||
LogicalPlan::Aggregate(aggr) => {
|
||||
// fast path: if the group_expr is a subset of partition_cols
|
||||
if Self::check_partition(&aggr.group_expr, &partition_cols) {
|
||||
return Commutativity::Commutative;
|
||||
}
|
||||
|
||||
// check all children exprs and uses the strictest level
|
||||
Commutativity::Unimplemented
|
||||
common_telemetry::info!("[DEBUG] aggregate plan expr: {:?}", aggr.aggr_expr);
|
||||
|
||||
// get all commutativity levels of aggregate exprs and find the strictest one
|
||||
let aggr_expr_comm = aggr
|
||||
.aggr_expr
|
||||
.iter()
|
||||
.map(Self::check_expr)
|
||||
.collect::<Vec<_>>();
|
||||
let mut strictest = Commutativity::Commutative;
|
||||
for comm in &aggr_expr_comm {
|
||||
if comm.is_stricter_than(&strictest) {
|
||||
strictest = comm.bare_level();
|
||||
}
|
||||
}
|
||||
|
||||
common_telemetry::info!("[DEBUG] aggr_expr_comm: {:?}", aggr_expr_comm);
|
||||
common_telemetry::info!("[DEBUG] strictest: {:?}", strictest);
|
||||
|
||||
// fast path: if any expr is commutative or non-commutative
|
||||
if matches!(
|
||||
strictest,
|
||||
Commutativity::Commutative
|
||||
| Commutativity::NonCommutative
|
||||
| Commutativity::Unimplemented
|
||||
| Commutativity::Unsupported
|
||||
) {
|
||||
return strictest.bare_level();
|
||||
}
|
||||
|
||||
common_telemetry::info!("[DEBUG] continue for strictest",);
|
||||
|
||||
// collect expr transformers
|
||||
let mut expr_transformer = Vec::with_capacity(aggr.aggr_expr.len());
|
||||
for expr_comm in aggr_expr_comm {
|
||||
match expr_comm {
|
||||
Commutativity::Commutative => expr_transformer.push(None),
|
||||
Commutativity::ConditionalCommutative(transformer) => {
|
||||
expr_transformer.push(transformer.clone());
|
||||
}
|
||||
Commutativity::PartialCommutative => expr_transformer
|
||||
.push(Some(Arc::new(expr_partial_commutative_transformer))),
|
||||
_ => expr_transformer.push(None),
|
||||
}
|
||||
}
|
||||
|
||||
// build plan transformer
|
||||
let transformer = Arc::new(move |plan: &LogicalPlan| {
|
||||
if let LogicalPlan::Aggregate(aggr) = plan {
|
||||
let mut new_plan = aggr.clone();
|
||||
for (expr, transformer) in
|
||||
new_plan.aggr_expr.iter_mut().zip(&expr_transformer)
|
||||
{
|
||||
if let Some(transformer) = transformer {
|
||||
let new_expr = transformer(expr)?;
|
||||
*expr = new_expr;
|
||||
}
|
||||
}
|
||||
common_telemetry::info!(
|
||||
"[DEBUG] new plan aggr expr: {:?}, group expr: {:?}",
|
||||
new_plan.aggr_expr,
|
||||
new_plan.group_expr
|
||||
);
|
||||
Some(LogicalPlan::Aggregate(new_plan))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
common_telemetry::info!("[DEBUG] done TransformedCommutative for aggr plan ");
|
||||
|
||||
Commutativity::TransformedCommutative(Some(transformer))
|
||||
}
|
||||
LogicalPlan::Sort(_) => {
|
||||
if partition_cols.is_empty() {
|
||||
@@ -111,7 +254,7 @@ impl Categorizer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_extension_plan(plan: &dyn UserDefinedLogicalNode) -> Commutativity {
|
||||
pub fn check_extension_plan(plan: &dyn UserDefinedLogicalNode) -> Commutativity<LogicalPlan> {
|
||||
match plan.name() {
|
||||
name if name == EmptyMetric::name()
|
||||
|| name == InstantManipulate::name()
|
||||
@@ -126,7 +269,7 @@ impl Categorizer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_expr(expr: &Expr) -> Commutativity {
|
||||
pub fn check_expr(expr: &Expr) -> Commutativity<Expr> {
|
||||
match expr {
|
||||
Expr::Column(_)
|
||||
| Expr::ScalarVariable(_, _)
|
||||
@@ -152,13 +295,14 @@ impl Categorizer {
|
||||
| Expr::Case(_)
|
||||
| Expr::Cast(_)
|
||||
| Expr::TryCast(_)
|
||||
| Expr::AggregateFunction(_)
|
||||
| Expr::WindowFunction(_)
|
||||
| Expr::InList(_)
|
||||
| Expr::InSubquery(_)
|
||||
| Expr::ScalarSubquery(_)
|
||||
| Expr::Wildcard { .. } => Commutativity::Unimplemented,
|
||||
|
||||
Expr::AggregateFunction(aggr_fn) => Self::check_aggregate_fn(aggr_fn),
|
||||
|
||||
Expr::Alias(_)
|
||||
| Expr::Unnest(_)
|
||||
| Expr::GroupingSet(_)
|
||||
@@ -167,6 +311,59 @@ impl Categorizer {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_aggregate_fn(aggr_fn: &AggregateFunction) -> Commutativity<Expr> {
|
||||
common_telemetry::info!("[DEBUG] checking aggr_fn: {:?}", aggr_fn);
|
||||
match &aggr_fn.func_def {
|
||||
AggregateFunctionDefinition::BuiltIn(func_def) => match func_def {
|
||||
BuiltInAggregateFunction::Max | BuiltInAggregateFunction::Min => {
|
||||
// Commutativity::PartialCommutative
|
||||
common_telemetry::info!("[DEBUG] checking min/max: {:?}", aggr_fn);
|
||||
let mut new_fn = aggr_fn.clone();
|
||||
let col_name = Expr::AggregateFunction(aggr_fn.clone())
|
||||
.name_for_alias()
|
||||
.expect("not a sort expr");
|
||||
let alias = col_name.clone();
|
||||
new_fn.args = vec![Expr::Column(col_name.into())];
|
||||
|
||||
// new_fn.func_def =
|
||||
// AggregateFunctionDefinition::BuiltIn(BuiltInAggregateFunction::Sum);
|
||||
Commutativity::ConditionalCommutative(Some(Arc::new(move |_| {
|
||||
common_telemetry::info!("[DEBUG] transforming min/max fn: {:?}", new_fn);
|
||||
Some(Expr::AggregateFunction(new_fn.clone()).alias(alias.clone()))
|
||||
})))
|
||||
}
|
||||
BuiltInAggregateFunction::Count => {
|
||||
common_telemetry::info!("[DEBUG] checking count_fn: {:?}", aggr_fn);
|
||||
let col_name = Expr::AggregateFunction(aggr_fn.clone())
|
||||
.name_for_alias()
|
||||
.expect("not a sort expr");
|
||||
let sum_udf = Arc::new(AggregateUDF::new_from_impl(Sum::new()));
|
||||
let alias = col_name.clone();
|
||||
// let sum_func = Arc::new(AggregateFunction::new_udf(
|
||||
// sum_udf,
|
||||
// vec![Expr::Column(col_name.into())],
|
||||
// false,
|
||||
// None,
|
||||
// None,
|
||||
// None,
|
||||
// ));
|
||||
let mut sum_expr = aggr_fn.clone();
|
||||
sum_expr.func_def = AggregateFunctionDefinition::UDF(sum_udf);
|
||||
sum_expr.args = vec![Expr::Column(col_name.into())];
|
||||
// let mut sum_fn = aggr_fn.clone();
|
||||
// sum_fn.func_def =
|
||||
// AggregateFunctionDefinition::BuiltIn(BuiltInAggregateFunction::Sum);
|
||||
Commutativity::ConditionalCommutative(Some(Arc::new(move |_| {
|
||||
common_telemetry::info!("[DEBUG] transforming sum_fn: {:?}", sum_expr);
|
||||
Some(Expr::AggregateFunction(sum_expr.clone()).alias(alias.clone()))
|
||||
})))
|
||||
}
|
||||
_ => Commutativity::Unimplemented,
|
||||
},
|
||||
AggregateFunctionDefinition::UDF(_) => Commutativity::Unimplemented,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return true if the given expr and partition cols satisfied the rule.
|
||||
/// In this case the plan can be treated as fully commutative.
|
||||
fn check_partition(exprs: &[Expr], partition_cols: &[String]) -> bool {
|
||||
@@ -188,12 +385,16 @@ impl Categorizer {
|
||||
}
|
||||
}
|
||||
|
||||
pub type Transformer = Arc<dyn Fn(&LogicalPlan) -> Option<LogicalPlan>>;
|
||||
/// Shared, reference-counted callback that maps a borrowed `T` to an optional new `T`.
/// Stored by the `ConditionalCommutative` and `TransformedCommutative` variants of
/// `Commutativity<T>`.
pub type Transformer<T> = Arc<dyn for<'a> Fn(&'a T) -> Option<T>>;
|
||||
|
||||
/// Plan transformer that hands back an unmodified copy of `plan`; it never returns `None`.
pub fn partial_commutative_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
    let unchanged = plan.clone();
    Some(unchanged)
}
|
||||
|
||||
/// Expression transformer that hands back an unmodified copy of `expr`; it never
/// returns `None`.
pub fn expr_partial_commutative_transformer(expr: &Expr) -> Option<Expr> {
    let unchanged = expr.clone();
    Some(unchanged)
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use datafusion_expr::{LogicalPlanBuilder, Sort};
|
||||
|
||||
Reference in New Issue
Block a user