feat: register all aggregate functions to auto step aggr fn (#6596)

* feat: support generic aggr push down

Signed-off-by: discord9 <discord9@163.com>

* typo

Signed-off-by: discord9 <discord9@163.com>

* fix: type ck in merge wrapper

Signed-off-by: discord9 <discord9@163.com>

* test: update sqlness

Signed-off-by: discord9 <discord9@163.com>

* feat: support all registered aggr func

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2025-08-05 19:37:45 +08:00
committed by Zhenchi
parent 569d93c599
commit 469c3140fe
12 changed files with 510 additions and 442 deletions

View File

@@ -26,6 +26,8 @@ use std::sync::Arc;
use arrow::array::StructArray;
use arrow_schema::Fields;
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
@@ -39,6 +41,8 @@ use datafusion_expr::{
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};
use crate::function_registry::FunctionRegistry;
/// Returns the name of the state function for the given aggregate function name.
/// The state function is used to compute the state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`
@@ -65,9 +69,9 @@ pub struct StateMergeHelper;
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
/// Upper merge plan, which is the aggregate plan that merges the states of the state function.
pub upper_merge: Arc<LogicalPlan>,
pub upper_merge: LogicalPlan,
/// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
pub lower_state: Arc<LogicalPlan>,
pub lower_state: LogicalPlan,
}
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
@@ -83,6 +87,36 @@ pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFun
}
impl StateMergeHelper {
/// Register all the `state` function of supported aggregate functions.
/// Note that can't register `merge` function here, as it needs to be created from the original aggregate function with given input types.
pub fn register(registry: &FunctionRegistry) {
let all_default = all_default_aggregate_functions();
let greptime_custom_aggr_functions = registry.aggregate_functions();
// if our custom aggregate function have the same name as the default aggregate function, we will override it.
let supported = all_default
.into_iter()
.chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
.collect::<Vec<_>>();
debug!(
"Registering state functions for supported: {:?}",
supported.iter().map(|f| f.name()).collect::<Vec<_>>()
);
let state_func = supported.into_iter().filter_map(|f| {
StateWrapper::new((*f).clone())
.inspect_err(
|e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
)
.ok()
.map(AggregateUDF::new_from_impl)
});
for func in state_func {
registry.register_aggr(func);
}
}
/// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
let aggr = {
@@ -166,18 +200,18 @@ impl StateMergeHelper {
let lower_plan = LogicalPlan::Aggregate(lower);
// update aggregate's output schema
let lower_plan = Arc::new(lower_plan.recompute_schema()?);
let lower_plan = lower_plan.recompute_schema()?;
let mut upper = aggr.clone();
let aggr_plan = LogicalPlan::Aggregate(aggr);
upper.aggr_expr = upper_aggr_exprs;
upper.input = lower_plan.clone();
upper.input = Arc::new(lower_plan.clone());
// upper schema's output schema should be the same as the original aggregate plan's output schema
let upper_check = upper.clone();
let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
let upper_check = upper;
let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
if *upper_plan.schema() != *aggr_plan.schema() {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[ original]{}",
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
upper_plan.schema(), aggr_plan.schema()
)));
}
@@ -407,15 +441,18 @@ impl AggregateUDFImpl for MergeWrapper {
&'a self,
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
) -> datafusion_common::Result<Box<dyn Accumulator>> {
if acc_args.schema.fields().len() != 1
|| !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
if acc_args.exprs.len() != 1
|| !matches!(
acc_args.exprs[0].data_type(acc_args.schema)?,
DataType::Struct(_)
)
{
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected one struct type as input, got: {:?}",
acc_args.schema
)));
}
let input_type = acc_args.schema.field(0).data_type();
let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let DataType::Struct(fields) = input_type else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected a struct type for input, got: {:?}",
@@ -424,7 +461,7 @@ impl AggregateUDFImpl for MergeWrapper {
};
let inner_accum = self.original_phy_expr.create_accumulator()?;
Ok(Box::new(MergeAccum::new(inner_accum, fields)))
Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
}
fn as_any(&self) -> &dyn std::any::Any {

View File

@@ -258,7 +258,7 @@ async fn test_sum_udaf() {
)
.recompute_schema()
.unwrap();
assert_eq!(res.lower_state.as_ref(), &expected_lower_plan);
assert_eq!(&res.lower_state, &expected_lower_plan);
let expected_merge_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
@@ -297,7 +297,7 @@ async fn test_sum_udaf() {
)
.unwrap(),
);
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
assert_eq!(&res.upper_merge, &expected_merge_plan);
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&res.lower_state, &ctx.state())
@@ -405,7 +405,7 @@ async fn test_avg_udaf() {
let coerced_aggr_state_plan = TypeCoercion::new()
.analyze(expected_aggr_state_plan.clone(), &Default::default())
.unwrap();
assert_eq!(res.lower_state.as_ref(), &coerced_aggr_state_plan);
assert_eq!(&res.lower_state, &coerced_aggr_state_plan);
assert_eq!(
res.lower_state.schema().as_arrow(),
&arrow_schema::Schema::new(vec![Field::new(
@@ -456,7 +456,7 @@ async fn test_avg_udaf() {
)
.unwrap(),
);
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
assert_eq!(&res.upper_merge, &expected_merge_plan);
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&coerced_aggr_state_plan, &ctx.state())

View File

@@ -20,6 +20,7 @@ use datafusion_expr::AggregateUDF;
use once_cell::sync::Lazy;
use crate::admin::AdminFunction;
use crate::aggrs::aggr_wrapper::StateMergeHelper;
use crate::aggrs::approximate::ApproximateFunction;
use crate::aggrs::count_hash::CountHash;
use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
@@ -105,6 +106,10 @@ impl FunctionRegistry {
.cloned()
.collect()
}
/// Returns `true` if an aggregate function is registered under exactly `name`.
///
/// # Panics
///
/// Panics if acquiring the read guard on the aggregate function map fails
/// (presumably a poisoned lock — confirm against the field's lock type).
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
    let registered = self.aggregate_functions.read().unwrap();
    registered.contains_key(name)
}
}
pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
@@ -148,6 +153,9 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
// CountHash function
CountHash::register(&function_registry);
// state function of supported aggregate functions
StateMergeHelper::register(&function_registry);
Arc::new(function_registry)
});

View File

@@ -369,6 +369,9 @@ impl PlanRewriter {
.collect::<Vec<_>>()
.join("\n")
);
if let Some(new_child_plan) = &transformer_actions.new_child_plan {
debug!("PlanRewriter: new child plan: {}", new_child_plan);
}
if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
// update the column requirements from the last stage
// notice current plan's parent plan is where we need to apply the column requirements
@@ -501,12 +504,12 @@ impl PlanRewriter {
}
fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
// store schema before expand, new child plan might have a different schema, so not using it
let schema = on_node.schema().clone();
if let Some(new_child_plan) = self.new_child_plan.take() {
// if there is a new child plan, use it as the new root
on_node = new_child_plan;
}
// store schema before expand
let schema = on_node.schema().clone();
let mut rewriter = EnforceDistRequirementRewriter::new(
std::mem::take(&mut self.column_requirements),
self.level,
@@ -514,7 +517,7 @@ impl PlanRewriter {
debug!("PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}");
on_node = on_node.rewrite(&mut rewriter)?.data;
debug!(
"PlanRewriter: after enforced column requirements for node: {on_node} with rewriter: {rewriter:?}"
"PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
);
// add merge scan as the new root
@@ -702,7 +705,10 @@ impl TreeNodeRewriter for PlanRewriter {
// TODO(ruihang): avoid this clone
if self.should_expand(&parent) {
// TODO(ruihang): does this work for nodes with multiple children?;
debug!("PlanRewriter: should expand child:\n {node}\n Of Parent: {parent}");
debug!(
"PlanRewriter: should expand child:\n {node}\n Of Parent: {}",
parent.display()
);
let node = self.expand(node);
debug!(
"PlanRewriter: expanded plan: {}",

View File

@@ -456,10 +456,9 @@ fn expand_proj_step_aggr() {
let expected = [
"Projection: min(t.number)",
" Projection: min(min(t.number)) AS min(t.number)",
" Aggregate: groupBy=[[]], aggr=[[min(min(t.number))]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[]], aggr=[[min(t.number)]]",
" Aggregate: groupBy=[[]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[]], aggr=[[__min_state(t.number)]]",
" Projection: t.number", // This Projection shouldn't add new column requirements
" TableScan: t",
"]]",
@@ -502,10 +501,9 @@ fn expand_proj_alias_fake_part_col_aggr() {
let expected = [
"Projection: pk1, pk2, min(t.number)",
" Projection: pk1, pk2, min(min(t.number)) AS min(t.number)",
" Aggregate: groupBy=[[pk1, pk2]], aggr=[[min(min(t.number))]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[pk1, pk2]], aggr=[[min(t.number)]]",
" Aggregate: groupBy=[[pk1, pk2]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[pk1, pk2]], aggr=[[__min_state(t.number)]]",
" Projection: t.number, pk1 AS pk2, pk3 AS pk1",
" Projection: t.number, t.pk3 AS pk1, t.pk2 AS pk3",
" TableScan: t",
@@ -583,10 +581,9 @@ fn expand_part_col_aggr_step_aggr() {
let expected = [
"Projection: min(max(t.number))",
" Projection: min(min(max(t.number))) AS min(max(t.number))",
" Aggregate: groupBy=[[]], aggr=[[min(min(max(t.number)))]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[]], aggr=[[min(max(t.number))]]",
" Aggregate: groupBy=[[]], aggr=[[__min_merge(__min_state(max(t.number))) AS min(max(t.number))]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[]], aggr=[[__min_state(max(t.number))]]",
" Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[max(t.number)]]",
" TableScan: t",
"]]",
@@ -618,10 +615,9 @@ fn expand_step_aggr_step_aggr() {
let expected = [
"Aggregate: groupBy=[[]], aggr=[[min(max(t.number))]]",
" Projection: max(t.number)",
" Projection: max(max(t.number)) AS max(t.number)",
" Aggregate: groupBy=[[]], aggr=[[max(max(t.number))]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[]], aggr=[[max(t.number)]]",
" Aggregate: groupBy=[[]], aggr=[[__max_merge(__max_state(t.number)) AS max(t.number)]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[]], aggr=[[__max_state(t.number)]]",
" TableScan: t",
"]]",
]
@@ -695,10 +691,9 @@ fn expand_step_aggr_proj() {
let expected = [
"Projection: min(t.number)",
" Projection: t.pk1, min(t.number)",
" Projection: t.pk1, min(min(t.number)) AS min(t.number)",
" Aggregate: groupBy=[[t.pk1]], aggr=[[min(min(t.number))]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[t.pk1]], aggr=[[min(t.number)]]",
" Aggregate: groupBy=[[t.pk1]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[t.pk1]], aggr=[[__min_state(t.number)]]",
" TableScan: t",
"]]",
]
@@ -1109,10 +1104,42 @@ fn expand_step_aggr_limit() {
let expected = [
"Limit: skip=0, fetch=10",
" Projection: t.pk1, min(t.number)",
" Projection: t.pk1, min(min(t.number)) AS min(t.number)",
" Aggregate: groupBy=[[t.pk1]], aggr=[[min(min(t.number))]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[t.pk1]], aggr=[[min(t.number)]]",
" Aggregate: groupBy=[[t.pk1]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[t.pk1]], aggr=[[__min_state(t.number)]]",
" TableScan: t",
"]]",
]
.join("\n");
assert_eq!(expected, result.to_string());
}
/// Test how avg get expanded
#[test]
fn expand_step_aggr_avg_limit() {
// use logging for better debugging
init_default_ut_logging();
let test_table = TestTable::table_with_name(0, "numbers".to_string());
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(test_table),
)));
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
.unwrap()
.aggregate(vec![col("pk1")], vec![avg(col("number"))])
.unwrap()
.limit(0, Some(10))
.unwrap()
.build()
.unwrap();
let config = ConfigOptions::default();
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
let expected = [
"Limit: skip=0, fetch=10",
" Projection: t.pk1, avg(t.number)",
" Aggregate: groupBy=[[t.pk1]], aggr=[[__avg_merge(__avg_state(t.number)) AS avg(t.number)]]",
" MergeScan [is_placeholder=false, remote_input=[",
"Aggregate: groupBy=[[t.pk1]], aggr=[[__avg_state(CAST(t.number AS Float64))]]",
" TableScan: t",
"]]",
]

View File

@@ -15,14 +15,10 @@
use std::collections::HashSet;
use std::sync::Arc;
use common_function::aggrs::approximate::hll::{HllState, HLL_MERGE_NAME, HLL_NAME};
use common_function::aggrs::approximate::uddsketch::{
UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
};
use common_function::aggrs::aggr_wrapper::{aggr_state_func_name, StateMergeHelper};
use common_function::function_registry::FUNCTION_REGISTRY;
use common_telemetry::debug;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion_common::Column;
use datafusion_expr::{Expr, LogicalPlan, Projection, UserDefinedLogicalNode};
use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
use promql::extension_plan::{
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
};
@@ -31,6 +27,11 @@ use crate::dist_plan::analyzer::AliasMapping;
use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
use crate::dist_plan::MergeScanLogicalPlan;
/// Result of turning a single aggregate plan into a step (state + merge) aggregation.
///
/// Produced by `step_aggr_to_upper_aggr` and copied field by field into a
/// `TransformerAction` by the `Categorizer`.
pub struct StepTransformAction {
/// Plans to place above the rewritten child; `step_aggr_to_upper_aggr` fills this
/// with the upper merge aggregate plan.
extra_parent_plans: Vec<LogicalPlan>,
/// Plan that replaces the original child; `step_aggr_to_upper_aggr` fills this
/// with the lower state aggregate plan.
new_child_plan: Option<LogicalPlan>,
}
/// generate the upper aggregation plan that will execute on the frontend.
/// Basically a logical plan resembling the following:
/// Projection:
@@ -42,111 +43,32 @@ use crate::dist_plan::MergeScanLogicalPlan;
/// of the upper aggregation plan.
pub fn step_aggr_to_upper_aggr(
aggr_plan: &LogicalPlan,
) -> datafusion_common::Result<[LogicalPlan; 2]> {
) -> datafusion_common::Result<StepTransformAction> {
let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
return Err(datafusion_common::DataFusionError::Plan(
"step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
));
};
if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
return Err(datafusion_common::DataFusionError::NotImplemented(
"Some aggregate expressions are not steppable".to_string(),
));
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
"Some aggregate expressions are not steppable in [{}]",
input_aggr
.aggr_expr
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ")
)));
}
let mut upper_aggr_expr = vec![];
for aggr_expr in &input_aggr.aggr_expr {
let Some(aggr_func) = get_aggr_func(aggr_expr) else {
return Err(datafusion_common::DataFusionError::NotImplemented(
"Aggregate function not found".to_string(),
));
};
let col_name = aggr_expr.qualified_name();
let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
let upper_func = match aggr_func.func.name() {
"sum" | "min" | "max" | "last_value" | "first_value" => {
// aggr_calc(aggr_merge(input_column))) as col_name
let mut new_aggr_func = aggr_func.clone();
new_aggr_func.args = vec![input_column.clone()];
new_aggr_func
}
"count" => {
// sum(input_column) as col_name
let mut new_aggr_func = aggr_func.clone();
new_aggr_func.func = sum_udaf();
new_aggr_func.args = vec![input_column.clone()];
new_aggr_func
}
UDDSKETCH_STATE_NAME | UDDSKETCH_MERGE_NAME => {
// udd_merge(bucket_size, error_rate input_column) as col_name
let mut new_aggr_func = aggr_func.clone();
new_aggr_func.func = Arc::new(UddSketchState::merge_udf_impl());
new_aggr_func.args[2] = input_column.clone();
new_aggr_func
}
HLL_NAME | HLL_MERGE_NAME => {
// hll_merge(input_column) as col_name
let mut new_aggr_func = aggr_func.clone();
new_aggr_func.func = Arc::new(HllState::merge_udf_impl());
new_aggr_func.args = vec![input_column.clone()];
new_aggr_func
}
_ => {
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
"Aggregate function {} is not supported for Step aggregation",
aggr_func.func.name()
)))
}
};
// deal with nested alias case
let mut new_aggr_expr = aggr_expr.clone();
{
let new_aggr_func = get_aggr_func_mut(&mut new_aggr_expr).unwrap();
*new_aggr_func = upper_func;
}
let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
upper_aggr_expr.push(new_aggr_expr);
}
let mut new_aggr = input_aggr.clone();
// use lower aggregate plan as input, this will be replace by merge scan plan later
new_aggr.input = Arc::new(LogicalPlan::Aggregate(input_aggr.clone()));
new_aggr.aggr_expr = upper_aggr_expr;
// group by expr also need to be all ref by column to avoid duplicated computing
let mut new_group_expr = new_aggr.group_expr.clone();
for expr in &mut new_group_expr {
if let Expr::Column(_) = expr {
// already a column, no need to change
continue;
}
let col_name = expr.qualified_name();
let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
*expr = input_column;
}
new_aggr.group_expr = new_group_expr.clone();
let mut new_projection_exprs = new_group_expr;
// the upper aggr expr need to be aliased to the input aggr expr's name,
// so that the parent plan can recognize it.
for (lower_aggr_expr, upper_aggr_expr) in
input_aggr.aggr_expr.iter().zip(new_aggr.aggr_expr.iter())
{
let lower_col_name = lower_aggr_expr.qualified_name();
let (table, col_name) = upper_aggr_expr.qualified_name();
let aggr_out_column = Column::new(table, col_name);
let aliased_output_aggr_expr =
Expr::Column(aggr_out_column).alias_qualified(lower_col_name.0, lower_col_name.1);
new_projection_exprs.push(aliased_output_aggr_expr);
}
let upper_aggr_plan = LogicalPlan::Aggregate(new_aggr);
let upper_aggr_plan = upper_aggr_plan.recompute_schema()?;
// create a projection on top of the new aggregate plan
let new_projection =
Projection::try_new(new_projection_exprs, Arc::new(upper_aggr_plan.clone()))?;
let projection = LogicalPlan::Projection(new_projection);
// return the new logical plan
Ok([projection, upper_aggr_plan])
// TODO(discord9): remove duplication
let ret = StepTransformAction {
extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
new_child_plan: Some(step_aggr_plan.lower_state.clone()),
};
Ok(ret)
}
/// Check if the given aggregate expression is steppable.
@@ -154,25 +76,15 @@ pub fn step_aggr_to_upper_aggr(
/// i.e. on datanode first call `state(input)` then
/// on frontend call `calc(merge(state))` to get the final result.
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
let step_action = HashSet::from([
"sum",
"count",
"min",
"max",
"first_value",
"last_value",
UDDSKETCH_STATE_NAME,
UDDSKETCH_MERGE_NAME,
HLL_NAME,
HLL_MERGE_NAME,
]);
aggr_exprs.iter().all(|expr| {
if let Some(aggr_func) = get_aggr_func(expr) {
if aggr_func.distinct {
// Distinct aggregate functions are not steppable(yet).
return false;
}
step_action.contains(aggr_func.func.name())
// whether the corresponding state function exists in the registry
FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
} else {
false
}
@@ -191,18 +103,6 @@ pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFun
}
}
/// Returns a mutable reference to the aggregate function wrapped by `expr`.
///
/// Any number of nested `Expr::Alias` layers are peeled off first. Returns `None`
/// when the innermost expression is not an `Expr::AggregateFunction`.
pub fn get_aggr_func_mut(expr: &mut Expr) -> Option<&mut datafusion_expr::expr::AggregateFunction> {
    match expr {
        // Look through the alias and keep unwrapping.
        Expr::Alias(alias) => get_aggr_func_mut(&mut alias.expr),
        Expr::AggregateFunction(aggr_func) => Some(aggr_func),
        _ => None,
    }
}
#[allow(dead_code)]
pub enum Commutativity {
Commutative,
@@ -247,8 +147,8 @@ impl Categorizer {
debug!("Before Step optimize: {plan}");
let ret = step_aggr_to_upper_aggr(plan);
ret.ok().map(|s| TransformerAction {
extra_parent_plans: s.to_vec(),
new_child_plan: None,
extra_parent_plans: s.extra_parent_plans,
new_child_plan: s.new_child_plan,
})
})),
};