mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-24 08:50:40 +00:00
feat: register all aggregate function to auto step aggr fn (#6596)
* feat: support generic aggr push down Signed-off-by: discord9 <discord9@163.com> * typo Signed-off-by: discord9 <discord9@163.com> * fix: type check in merge wrapper Signed-off-by: discord9 <discord9@163.com> * test: update sqlness Signed-off-by: discord9 <discord9@163.com> * feat: support all registered aggr func Signed-off-by: discord9 <discord9@163.com> * chore: per review Signed-off-by: discord9 <discord9@163.com> * chore: per review Signed-off-by: discord9 <discord9@163.com> --------- Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
@@ -26,6 +26,8 @@ use std::sync::Arc;
|
||||
|
||||
use arrow::array::StructArray;
|
||||
use arrow_schema::Fields;
|
||||
use common_telemetry::debug;
|
||||
use datafusion::functions_aggregate::all_default_aggregate_functions;
|
||||
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
|
||||
use datafusion::optimizer::AnalyzerRule;
|
||||
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
|
||||
@@ -39,6 +41,8 @@ use datafusion_expr::{
|
||||
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
|
||||
use datatypes::arrow::datatypes::{DataType, Field};
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
/// Returns the name of the state function for the given aggregate function name.
|
||||
/// The state function is used to compute the state of the aggregate function.
|
||||
/// The state function's name is in the format `__<aggr_name>_state`.
|
||||
@@ -65,9 +69,9 @@ pub struct StateMergeHelper;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StepAggrPlan {
|
||||
/// Upper merge plan, which is the aggregate plan that merges the states of the state function.
|
||||
pub upper_merge: Arc<LogicalPlan>,
|
||||
pub upper_merge: LogicalPlan,
|
||||
/// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
|
||||
pub lower_state: Arc<LogicalPlan>,
|
||||
pub lower_state: LogicalPlan,
|
||||
}
|
||||
|
||||
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
|
||||
@@ -83,6 +87,36 @@ pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFun
|
||||
}
|
||||
|
||||
impl StateMergeHelper {
|
||||
/// Registers the `state` function of every supported aggregate function.
|
||||
/// Note that the `merge` functions can't be registered here, as each one needs to be created from the original aggregate function with the given input types.
|
||||
///
/// The supported set is DataFusion's default aggregate functions chained with the
/// aggregate functions returned by `registry.aggregate_functions()`. Each function is
/// wrapped in a `StateWrapper`; when a wrapper cannot be built the error is logged and
/// that function is skipped, so one failure does not abort the whole registration.
pub fn register(registry: &FunctionRegistry) {
|
||||
let all_default = all_default_aggregate_functions();
|
||||
let greptime_custom_aggr_functions = registry.aggregate_functions();
|
||||
|
||||
// If one of our custom aggregate functions has the same name as a default aggregate function, the custom one overrides it.
// NOTE(review): this relies on the custom functions coming last in the chain and on
// `register_aggr` replacing an existing entry with the same name; confirm.
|
||||
let supported = all_default
|
||||
.into_iter()
|
||||
// The custom functions are wrapped in `Arc` so both sources yield the same item type.
.chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
|
||||
.collect::<Vec<_>>();
|
||||
debug!(
|
||||
"Registering state functions for supported: {:?}",
|
||||
supported.iter().map(|f| f.name()).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
// Build the state UDAF for each supported function; `filter_map` drops the ones
// whose `StateWrapper` could not be created after the error has been logged.
let state_func = supported.into_iter().filter_map(|f| {
|
||||
// `f` is behind an `Arc`, so it is dereferenced and cloned to hand an owned value to the wrapper.
StateWrapper::new((*f).clone())
|
||||
.inspect_err(
|
||||
|e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
|
||||
)
|
||||
.ok()
|
||||
.map(AggregateUDF::new_from_impl)
|
||||
});
|
||||
|
||||
for func in state_func {
|
||||
registry.register_aggr(func);
|
||||
}
|
||||
}
|
||||
|
||||
/// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
|
||||
pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
|
||||
let aggr = {
|
||||
@@ -166,18 +200,18 @@ impl StateMergeHelper {
|
||||
let lower_plan = LogicalPlan::Aggregate(lower);
|
||||
|
||||
// update aggregate's output schema
|
||||
let lower_plan = Arc::new(lower_plan.recompute_schema()?);
|
||||
let lower_plan = lower_plan.recompute_schema()?;
|
||||
|
||||
let mut upper = aggr.clone();
|
||||
let aggr_plan = LogicalPlan::Aggregate(aggr);
|
||||
upper.aggr_expr = upper_aggr_exprs;
|
||||
upper.input = lower_plan.clone();
|
||||
upper.input = Arc::new(lower_plan.clone());
|
||||
// the upper plan's output schema should be the same as the original aggregate plan's output schema
|
||||
let upper_check = upper.clone();
|
||||
let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
|
||||
let upper_check = upper;
|
||||
let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
|
||||
if *upper_plan.schema() != *aggr_plan.schema() {
|
||||
return Err(datafusion_common::DataFusionError::Internal(format!(
|
||||
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[ original]{}",
|
||||
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
|
||||
upper_plan.schema(), aggr_plan.schema()
|
||||
)));
|
||||
}
|
||||
@@ -407,15 +441,18 @@ impl AggregateUDFImpl for MergeWrapper {
|
||||
&'a self,
|
||||
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
|
||||
) -> datafusion_common::Result<Box<dyn Accumulator>> {
|
||||
if acc_args.schema.fields().len() != 1
|
||||
|| !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
|
||||
if acc_args.exprs.len() != 1
|
||||
|| !matches!(
|
||||
acc_args.exprs[0].data_type(acc_args.schema)?,
|
||||
DataType::Struct(_)
|
||||
)
|
||||
{
|
||||
return Err(datafusion_common::DataFusionError::Internal(format!(
|
||||
"Expected one struct type as input, got: {:?}",
|
||||
acc_args.schema
|
||||
)));
|
||||
}
|
||||
let input_type = acc_args.schema.field(0).data_type();
|
||||
let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
|
||||
let DataType::Struct(fields) = input_type else {
|
||||
return Err(datafusion_common::DataFusionError::Internal(format!(
|
||||
"Expected a struct type for input, got: {:?}",
|
||||
@@ -424,7 +461,7 @@ impl AggregateUDFImpl for MergeWrapper {
|
||||
};
|
||||
|
||||
let inner_accum = self.original_phy_expr.create_accumulator()?;
|
||||
Ok(Box::new(MergeAccum::new(inner_accum, fields)))
|
||||
Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
|
||||
@@ -258,7 +258,7 @@ async fn test_sum_udaf() {
|
||||
)
|
||||
.recompute_schema()
|
||||
.unwrap();
|
||||
assert_eq!(res.lower_state.as_ref(), &expected_lower_plan);
|
||||
assert_eq!(&res.lower_state, &expected_lower_plan);
|
||||
|
||||
let expected_merge_plan = LogicalPlan::Aggregate(
|
||||
Aggregate::try_new(
|
||||
@@ -297,7 +297,7 @@ async fn test_sum_udaf() {
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
|
||||
assert_eq!(&res.upper_merge, &expected_merge_plan);
|
||||
|
||||
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
|
||||
.create_physical_plan(&res.lower_state, &ctx.state())
|
||||
@@ -405,7 +405,7 @@ async fn test_avg_udaf() {
|
||||
let coerced_aggr_state_plan = TypeCoercion::new()
|
||||
.analyze(expected_aggr_state_plan.clone(), &Default::default())
|
||||
.unwrap();
|
||||
assert_eq!(res.lower_state.as_ref(), &coerced_aggr_state_plan);
|
||||
assert_eq!(&res.lower_state, &coerced_aggr_state_plan);
|
||||
assert_eq!(
|
||||
res.lower_state.schema().as_arrow(),
|
||||
&arrow_schema::Schema::new(vec![Field::new(
|
||||
@@ -456,7 +456,7 @@ async fn test_avg_udaf() {
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
|
||||
assert_eq!(&res.upper_merge, &expected_merge_plan);
|
||||
|
||||
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
|
||||
.create_physical_plan(&coerced_aggr_state_plan, &ctx.state())
|
||||
|
||||
@@ -20,6 +20,7 @@ use datafusion_expr::AggregateUDF;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use crate::admin::AdminFunction;
|
||||
use crate::aggrs::aggr_wrapper::StateMergeHelper;
|
||||
use crate::aggrs::approximate::ApproximateFunction;
|
||||
use crate::aggrs::count_hash::CountHash;
|
||||
use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
|
||||
@@ -105,6 +106,10 @@ impl FunctionRegistry {
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns `true` when the registry's aggregate function map contains the key `name`.
///
/// The map is accessed through `.read().unwrap()`, so this panics if acquiring the
/// read guard fails (presumably a poisoned `RwLock`; confirm against the field's type).
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
|
||||
self.aggregate_functions.read().unwrap().contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
|
||||
@@ -148,6 +153,9 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
|
||||
// CountHash function
|
||||
CountHash::register(&function_registry);
|
||||
|
||||
// state function of supported aggregate functions
|
||||
StateMergeHelper::register(&function_registry);
|
||||
|
||||
Arc::new(function_registry)
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user