feat: register all aggregate functions to auto step aggr fn (#6596)

* feat: support generic aggr push down

Signed-off-by: discord9 <discord9@163.com>

* typo

Signed-off-by: discord9 <discord9@163.com>

* fix: type ck in merge wrapper

Signed-off-by: discord9 <discord9@163.com>

* test: update sqlness

Signed-off-by: discord9 <discord9@163.com>

* feat: support all registered aggr func

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2025-08-05 19:37:45 +08:00
committed by GitHub
parent 9871c22740
commit 875207d26c
12 changed files with 510 additions and 442 deletions

View File

@@ -26,6 +26,8 @@ use std::sync::Arc;
use arrow::array::StructArray;
use arrow_schema::Fields;
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
@@ -39,6 +41,8 @@ use datafusion_expr::{
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};
use crate::function_registry::FunctionRegistry;
/// Returns the name of the state function for the given aggregate function name.
/// The state function is used to compute the state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`
@@ -65,9 +69,9 @@ pub struct StateMergeHelper;
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
/// Upper merge plan, which is the aggregate plan that merges the states of the state function.
pub upper_merge: Arc<LogicalPlan>,
pub upper_merge: LogicalPlan,
/// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
pub lower_state: Arc<LogicalPlan>,
pub lower_state: LogicalPlan,
}
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
@@ -83,6 +87,36 @@ pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFun
}
impl StateMergeHelper {
/// Register all the `state` function of supported aggregate functions.
/// Note that can't register `merge` function here, as it needs to be created from the original aggregate function with given input types.
pub fn register(registry: &FunctionRegistry) {
let all_default = all_default_aggregate_functions();
let greptime_custom_aggr_functions = registry.aggregate_functions();
// if our custom aggregate function have the same name as the default aggregate function, we will override it.
let supported = all_default
.into_iter()
.chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
.collect::<Vec<_>>();
debug!(
"Registering state functions for supported: {:?}",
supported.iter().map(|f| f.name()).collect::<Vec<_>>()
);
let state_func = supported.into_iter().filter_map(|f| {
StateWrapper::new((*f).clone())
.inspect_err(
|e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
)
.ok()
.map(AggregateUDF::new_from_impl)
});
for func in state_func {
registry.register_aggr(func);
}
}
/// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
let aggr = {
@@ -166,18 +200,18 @@ impl StateMergeHelper {
let lower_plan = LogicalPlan::Aggregate(lower);
// update aggregate's output schema
let lower_plan = Arc::new(lower_plan.recompute_schema()?);
let lower_plan = lower_plan.recompute_schema()?;
let mut upper = aggr.clone();
let aggr_plan = LogicalPlan::Aggregate(aggr);
upper.aggr_expr = upper_aggr_exprs;
upper.input = lower_plan.clone();
upper.input = Arc::new(lower_plan.clone());
// upper schema's output schema should be the same as the original aggregate plan's output schema
let upper_check = upper.clone();
let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
let upper_check = upper;
let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
if *upper_plan.schema() != *aggr_plan.schema() {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[ original]{}",
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
upper_plan.schema(), aggr_plan.schema()
)));
}
@@ -407,15 +441,18 @@ impl AggregateUDFImpl for MergeWrapper {
&'a self,
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
) -> datafusion_common::Result<Box<dyn Accumulator>> {
if acc_args.schema.fields().len() != 1
|| !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
if acc_args.exprs.len() != 1
|| !matches!(
acc_args.exprs[0].data_type(acc_args.schema)?,
DataType::Struct(_)
)
{
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected one struct type as input, got: {:?}",
acc_args.schema
)));
}
let input_type = acc_args.schema.field(0).data_type();
let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let DataType::Struct(fields) = input_type else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected a struct type for input, got: {:?}",
@@ -424,7 +461,7 @@ impl AggregateUDFImpl for MergeWrapper {
};
let inner_accum = self.original_phy_expr.create_accumulator()?;
Ok(Box::new(MergeAccum::new(inner_accum, fields)))
Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
}
fn as_any(&self) -> &dyn std::any::Any {

View File

@@ -258,7 +258,7 @@ async fn test_sum_udaf() {
)
.recompute_schema()
.unwrap();
assert_eq!(res.lower_state.as_ref(), &expected_lower_plan);
assert_eq!(&res.lower_state, &expected_lower_plan);
let expected_merge_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
@@ -297,7 +297,7 @@ async fn test_sum_udaf() {
)
.unwrap(),
);
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
assert_eq!(&res.upper_merge, &expected_merge_plan);
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&res.lower_state, &ctx.state())
@@ -405,7 +405,7 @@ async fn test_avg_udaf() {
let coerced_aggr_state_plan = TypeCoercion::new()
.analyze(expected_aggr_state_plan.clone(), &Default::default())
.unwrap();
assert_eq!(res.lower_state.as_ref(), &coerced_aggr_state_plan);
assert_eq!(&res.lower_state, &coerced_aggr_state_plan);
assert_eq!(
res.lower_state.schema().as_arrow(),
&arrow_schema::Schema::new(vec![Field::new(
@@ -456,7 +456,7 @@ async fn test_avg_udaf() {
)
.unwrap(),
);
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
assert_eq!(&res.upper_merge, &expected_merge_plan);
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&coerced_aggr_state_plan, &ctx.state())

View File

@@ -20,6 +20,7 @@ use datafusion_expr::AggregateUDF;
use once_cell::sync::Lazy;
use crate::admin::AdminFunction;
use crate::aggrs::aggr_wrapper::StateMergeHelper;
use crate::aggrs::approximate::ApproximateFunction;
use crate::aggrs::count_hash::CountHash;
use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
@@ -105,6 +106,10 @@ impl FunctionRegistry {
.cloned()
.collect()
}
/// Returns `true` if an aggregate function is registered under `name`.
///
/// Takes a read guard on the `aggregate_functions` map for the duration of the lookup.
///
/// # Panics
///
/// Panics if acquiring the read guard returns an error (the result is unwrapped).
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
    let registered = self.aggregate_functions.read().unwrap();
    registered.contains_key(name)
}
}
pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
@@ -148,6 +153,9 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
// CountHash function
CountHash::register(&function_registry);
// state function of supported aggregate functions
StateMergeHelper::register(&function_registry);
Arc::new(function_registry)
});