mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-24 00:40:40 +00:00
feat: register all aggregate function to auto step aggr fn (#6596)
* feat: support generic aggr push down Signed-off-by: discord9 <discord9@163.com> * typo Signed-off-by: discord9 <discord9@163.com> * fix: type ck in merge wrapper Signed-off-by: discord9 <discord9@163.com> * test: update sqlness Signed-off-by: discord9 <discord9@163.com> * feat: support all registried aggr func Signed-off-by: discord9 <discord9@163.com> * chore: per review Signed-off-by: discord9 <discord9@163.com> * chore: per review Signed-off-by: discord9 <discord9@163.com> --------- Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
@@ -26,6 +26,8 @@ use std::sync::Arc;
|
||||
|
||||
use arrow::array::StructArray;
|
||||
use arrow_schema::Fields;
|
||||
use common_telemetry::debug;
|
||||
use datafusion::functions_aggregate::all_default_aggregate_functions;
|
||||
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
|
||||
use datafusion::optimizer::AnalyzerRule;
|
||||
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
|
||||
@@ -39,6 +41,8 @@ use datafusion_expr::{
|
||||
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
|
||||
use datatypes::arrow::datatypes::{DataType, Field};
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
/// Returns the name of the state function for the given aggregate function name.
|
||||
/// The state function is used to compute the state of the aggregate function.
|
||||
/// The state function's name is in the format `__<aggr_name>_state`.
|
||||
@@ -65,9 +69,9 @@ pub struct StateMergeHelper;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StepAggrPlan {
|
||||
/// Upper merge plan, which is the aggregate plan that merges the states of the state function.
|
||||
pub upper_merge: Arc<LogicalPlan>,
|
||||
pub upper_merge: LogicalPlan,
|
||||
/// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
|
||||
pub lower_state: Arc<LogicalPlan>,
|
||||
pub lower_state: LogicalPlan,
|
||||
}
|
||||
|
||||
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
|
||||
@@ -83,6 +87,36 @@ pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFun
|
||||
}
|
||||
|
||||
impl StateMergeHelper {
|
||||
/// Register all the `state` functions of supported aggregate functions.
|
||||
/// Note that the `merge` function can't be registered here, as it needs to be created from the original aggregate function with the given input types.
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
let all_default = all_default_aggregate_functions();
|
||||
let greptime_custom_aggr_functions = registry.aggregate_functions();
|
||||
|
||||
// if one of our custom aggregate functions has the same name as a default aggregate function, we will override the default one.
|
||||
let supported = all_default
|
||||
.into_iter()
|
||||
.chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
|
||||
.collect::<Vec<_>>();
|
||||
debug!(
|
||||
"Registering state functions for supported: {:?}",
|
||||
supported.iter().map(|f| f.name()).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let state_func = supported.into_iter().filter_map(|f| {
|
||||
StateWrapper::new((*f).clone())
|
||||
.inspect_err(
|
||||
|e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
|
||||
)
|
||||
.ok()
|
||||
.map(AggregateUDF::new_from_impl)
|
||||
});
|
||||
|
||||
for func in state_func {
|
||||
registry.register_aggr(func);
|
||||
}
|
||||
}
|
||||
|
||||
/// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
|
||||
pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
|
||||
let aggr = {
|
||||
@@ -166,18 +200,18 @@ impl StateMergeHelper {
|
||||
let lower_plan = LogicalPlan::Aggregate(lower);
|
||||
|
||||
// update aggregate's output schema
|
||||
let lower_plan = Arc::new(lower_plan.recompute_schema()?);
|
||||
let lower_plan = lower_plan.recompute_schema()?;
|
||||
|
||||
let mut upper = aggr.clone();
|
||||
let aggr_plan = LogicalPlan::Aggregate(aggr);
|
||||
upper.aggr_expr = upper_aggr_exprs;
|
||||
upper.input = lower_plan.clone();
|
||||
upper.input = Arc::new(lower_plan.clone());
|
||||
// upper schema's output schema should be the same as the original aggregate plan's output schema
|
||||
let upper_check = upper.clone();
|
||||
let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
|
||||
let upper_check = upper;
|
||||
let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
|
||||
if *upper_plan.schema() != *aggr_plan.schema() {
|
||||
return Err(datafusion_common::DataFusionError::Internal(format!(
|
||||
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[ original]{}",
|
||||
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
|
||||
upper_plan.schema(), aggr_plan.schema()
|
||||
)));
|
||||
}
|
||||
@@ -407,15 +441,18 @@ impl AggregateUDFImpl for MergeWrapper {
|
||||
&'a self,
|
||||
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
|
||||
) -> datafusion_common::Result<Box<dyn Accumulator>> {
|
||||
if acc_args.schema.fields().len() != 1
|
||||
|| !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
|
||||
if acc_args.exprs.len() != 1
|
||||
|| !matches!(
|
||||
acc_args.exprs[0].data_type(acc_args.schema)?,
|
||||
DataType::Struct(_)
|
||||
)
|
||||
{
|
||||
return Err(datafusion_common::DataFusionError::Internal(format!(
|
||||
"Expected one struct type as input, got: {:?}",
|
||||
acc_args.schema
|
||||
)));
|
||||
}
|
||||
let input_type = acc_args.schema.field(0).data_type();
|
||||
let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
|
||||
let DataType::Struct(fields) = input_type else {
|
||||
return Err(datafusion_common::DataFusionError::Internal(format!(
|
||||
"Expected a struct type for input, got: {:?}",
|
||||
@@ -424,7 +461,7 @@ impl AggregateUDFImpl for MergeWrapper {
|
||||
};
|
||||
|
||||
let inner_accum = self.original_phy_expr.create_accumulator()?;
|
||||
Ok(Box::new(MergeAccum::new(inner_accum, fields)))
|
||||
Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
|
||||
@@ -258,7 +258,7 @@ async fn test_sum_udaf() {
|
||||
)
|
||||
.recompute_schema()
|
||||
.unwrap();
|
||||
assert_eq!(res.lower_state.as_ref(), &expected_lower_plan);
|
||||
assert_eq!(&res.lower_state, &expected_lower_plan);
|
||||
|
||||
let expected_merge_plan = LogicalPlan::Aggregate(
|
||||
Aggregate::try_new(
|
||||
@@ -297,7 +297,7 @@ async fn test_sum_udaf() {
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
|
||||
assert_eq!(&res.upper_merge, &expected_merge_plan);
|
||||
|
||||
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
|
||||
.create_physical_plan(&res.lower_state, &ctx.state())
|
||||
@@ -405,7 +405,7 @@ async fn test_avg_udaf() {
|
||||
let coerced_aggr_state_plan = TypeCoercion::new()
|
||||
.analyze(expected_aggr_state_plan.clone(), &Default::default())
|
||||
.unwrap();
|
||||
assert_eq!(res.lower_state.as_ref(), &coerced_aggr_state_plan);
|
||||
assert_eq!(&res.lower_state, &coerced_aggr_state_plan);
|
||||
assert_eq!(
|
||||
res.lower_state.schema().as_arrow(),
|
||||
&arrow_schema::Schema::new(vec![Field::new(
|
||||
@@ -456,7 +456,7 @@ async fn test_avg_udaf() {
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
|
||||
assert_eq!(&res.upper_merge, &expected_merge_plan);
|
||||
|
||||
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
|
||||
.create_physical_plan(&coerced_aggr_state_plan, &ctx.state())
|
||||
|
||||
@@ -20,6 +20,7 @@ use datafusion_expr::AggregateUDF;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use crate::admin::AdminFunction;
|
||||
use crate::aggrs::aggr_wrapper::StateMergeHelper;
|
||||
use crate::aggrs::approximate::ApproximateFunction;
|
||||
use crate::aggrs::count_hash::CountHash;
|
||||
use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
|
||||
@@ -105,6 +106,10 @@ impl FunctionRegistry {
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
|
||||
self.aggregate_functions.read().unwrap().contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
|
||||
@@ -148,6 +153,9 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
|
||||
// CountHash function
|
||||
CountHash::register(&function_registry);
|
||||
|
||||
// state function of supported aggregate functions
|
||||
StateMergeHelper::register(&function_registry);
|
||||
|
||||
Arc::new(function_registry)
|
||||
});
|
||||
|
||||
|
||||
@@ -369,6 +369,9 @@ impl PlanRewriter {
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
);
|
||||
if let Some(new_child_plan) = &transformer_actions.new_child_plan {
|
||||
debug!("PlanRewriter: new child plan: {}", new_child_plan);
|
||||
}
|
||||
if let Some(last_stage) = transformer_actions.extra_parent_plans.last() {
|
||||
// update the column requirements from the last stage
|
||||
// notice current plan's parent plan is where we need to apply the column requirements
|
||||
@@ -501,12 +504,12 @@ impl PlanRewriter {
|
||||
}
|
||||
|
||||
fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
|
||||
// store schema before expand, new child plan might have a different schema, so not using it
|
||||
let schema = on_node.schema().clone();
|
||||
if let Some(new_child_plan) = self.new_child_plan.take() {
|
||||
// if there is a new child plan, use it as the new root
|
||||
on_node = new_child_plan;
|
||||
}
|
||||
// store schema before expand
|
||||
let schema = on_node.schema().clone();
|
||||
let mut rewriter = EnforceDistRequirementRewriter::new(
|
||||
std::mem::take(&mut self.column_requirements),
|
||||
self.level,
|
||||
@@ -514,7 +517,7 @@ impl PlanRewriter {
|
||||
debug!("PlanRewriter: enforce column requirements for node: {on_node} with rewriter: {rewriter:?}");
|
||||
on_node = on_node.rewrite(&mut rewriter)?.data;
|
||||
debug!(
|
||||
"PlanRewriter: after enforced column requirements for node: {on_node} with rewriter: {rewriter:?}"
|
||||
"PlanRewriter: after enforced column requirements with rewriter: {rewriter:?} for node:\n{on_node}"
|
||||
);
|
||||
|
||||
// add merge scan as the new root
|
||||
@@ -702,7 +705,10 @@ impl TreeNodeRewriter for PlanRewriter {
|
||||
// TODO(ruihang): avoid this clone
|
||||
if self.should_expand(&parent) {
|
||||
// TODO(ruihang): does this work for nodes with multiple children?
|
||||
debug!("PlanRewriter: should expand child:\n {node}\n Of Parent: {parent}");
|
||||
debug!(
|
||||
"PlanRewriter: should expand child:\n {node}\n Of Parent: {}",
|
||||
parent.display()
|
||||
);
|
||||
let node = self.expand(node);
|
||||
debug!(
|
||||
"PlanRewriter: expanded plan: {}",
|
||||
|
||||
@@ -456,10 +456,9 @@ fn expand_proj_step_aggr() {
|
||||
|
||||
let expected = [
|
||||
"Projection: min(t.number)",
|
||||
" Projection: min(min(t.number)) AS min(t.number)",
|
||||
" Aggregate: groupBy=[[]], aggr=[[min(min(t.number))]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[]], aggr=[[min(t.number)]]",
|
||||
" Aggregate: groupBy=[[]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[]], aggr=[[__min_state(t.number)]]",
|
||||
" Projection: t.number", // This Projection shouldn't add new column requirements
|
||||
" TableScan: t",
|
||||
"]]",
|
||||
@@ -502,10 +501,9 @@ fn expand_proj_alias_fake_part_col_aggr() {
|
||||
|
||||
let expected = [
|
||||
"Projection: pk1, pk2, min(t.number)",
|
||||
" Projection: pk1, pk2, min(min(t.number)) AS min(t.number)",
|
||||
" Aggregate: groupBy=[[pk1, pk2]], aggr=[[min(min(t.number))]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[pk1, pk2]], aggr=[[min(t.number)]]",
|
||||
" Aggregate: groupBy=[[pk1, pk2]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[pk1, pk2]], aggr=[[__min_state(t.number)]]",
|
||||
" Projection: t.number, pk1 AS pk2, pk3 AS pk1",
|
||||
" Projection: t.number, t.pk3 AS pk1, t.pk2 AS pk3",
|
||||
" TableScan: t",
|
||||
@@ -583,10 +581,9 @@ fn expand_part_col_aggr_step_aggr() {
|
||||
|
||||
let expected = [
|
||||
"Projection: min(max(t.number))",
|
||||
" Projection: min(min(max(t.number))) AS min(max(t.number))",
|
||||
" Aggregate: groupBy=[[]], aggr=[[min(min(max(t.number)))]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[]], aggr=[[min(max(t.number))]]",
|
||||
" Aggregate: groupBy=[[]], aggr=[[__min_merge(__min_state(max(t.number))) AS min(max(t.number))]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[]], aggr=[[__min_state(max(t.number))]]",
|
||||
" Aggregate: groupBy=[[t.pk1, t.pk2]], aggr=[[max(t.number)]]",
|
||||
" TableScan: t",
|
||||
"]]",
|
||||
@@ -618,10 +615,9 @@ fn expand_step_aggr_step_aggr() {
|
||||
let expected = [
|
||||
"Aggregate: groupBy=[[]], aggr=[[min(max(t.number))]]",
|
||||
" Projection: max(t.number)",
|
||||
" Projection: max(max(t.number)) AS max(t.number)",
|
||||
" Aggregate: groupBy=[[]], aggr=[[max(max(t.number))]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[]], aggr=[[max(t.number)]]",
|
||||
" Aggregate: groupBy=[[]], aggr=[[__max_merge(__max_state(t.number)) AS max(t.number)]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[]], aggr=[[__max_state(t.number)]]",
|
||||
" TableScan: t",
|
||||
"]]",
|
||||
]
|
||||
@@ -695,10 +691,9 @@ fn expand_step_aggr_proj() {
|
||||
let expected = [
|
||||
"Projection: min(t.number)",
|
||||
" Projection: t.pk1, min(t.number)",
|
||||
" Projection: t.pk1, min(min(t.number)) AS min(t.number)",
|
||||
" Aggregate: groupBy=[[t.pk1]], aggr=[[min(min(t.number))]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[t.pk1]], aggr=[[min(t.number)]]",
|
||||
" Aggregate: groupBy=[[t.pk1]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[t.pk1]], aggr=[[__min_state(t.number)]]",
|
||||
" TableScan: t",
|
||||
"]]",
|
||||
]
|
||||
@@ -1109,10 +1104,42 @@ fn expand_step_aggr_limit() {
|
||||
let expected = [
|
||||
"Limit: skip=0, fetch=10",
|
||||
" Projection: t.pk1, min(t.number)",
|
||||
" Projection: t.pk1, min(min(t.number)) AS min(t.number)",
|
||||
" Aggregate: groupBy=[[t.pk1]], aggr=[[min(min(t.number))]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[t.pk1]], aggr=[[min(t.number)]]",
|
||||
" Aggregate: groupBy=[[t.pk1]], aggr=[[__min_merge(__min_state(t.number)) AS min(t.number)]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[t.pk1]], aggr=[[__min_state(t.number)]]",
|
||||
" TableScan: t",
|
||||
"]]",
|
||||
]
|
||||
.join("\n");
|
||||
assert_eq!(expected, result.to_string());
|
||||
}
|
||||
|
||||
/// Test how avg gets expanded
|
||||
#[test]
|
||||
fn expand_step_aggr_avg_limit() {
|
||||
// use logging for better debugging
|
||||
init_default_ut_logging();
|
||||
let test_table = TestTable::table_with_name(0, "numbers".to_string());
|
||||
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
|
||||
DfTableProviderAdapter::new(test_table),
|
||||
)));
|
||||
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
|
||||
.unwrap()
|
||||
.aggregate(vec![col("pk1")], vec![avg(col("number"))])
|
||||
.unwrap()
|
||||
.limit(0, Some(10))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let config = ConfigOptions::default();
|
||||
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
|
||||
let expected = [
|
||||
"Limit: skip=0, fetch=10",
|
||||
" Projection: t.pk1, avg(t.number)",
|
||||
" Aggregate: groupBy=[[t.pk1]], aggr=[[__avg_merge(__avg_state(t.number)) AS avg(t.number)]]",
|
||||
" MergeScan [is_placeholder=false, remote_input=[",
|
||||
"Aggregate: groupBy=[[t.pk1]], aggr=[[__avg_state(CAST(t.number AS Float64))]]",
|
||||
" TableScan: t",
|
||||
"]]",
|
||||
]
|
||||
|
||||
@@ -15,14 +15,10 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_function::aggrs::approximate::hll::{HllState, HLL_MERGE_NAME, HLL_NAME};
|
||||
use common_function::aggrs::approximate::uddsketch::{
|
||||
UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME,
|
||||
};
|
||||
use common_function::aggrs::aggr_wrapper::{aggr_state_func_name, StateMergeHelper};
|
||||
use common_function::function_registry::FUNCTION_REGISTRY;
|
||||
use common_telemetry::debug;
|
||||
use datafusion::functions_aggregate::sum::sum_udaf;
|
||||
use datafusion_common::Column;
|
||||
use datafusion_expr::{Expr, LogicalPlan, Projection, UserDefinedLogicalNode};
|
||||
use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
|
||||
use promql::extension_plan::{
|
||||
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
|
||||
};
|
||||
@@ -31,6 +27,11 @@ use crate::dist_plan::analyzer::AliasMapping;
|
||||
use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
|
||||
use crate::dist_plan::MergeScanLogicalPlan;
|
||||
|
||||
pub struct StepTransformAction {
|
||||
extra_parent_plans: Vec<LogicalPlan>,
|
||||
new_child_plan: Option<LogicalPlan>,
|
||||
}
|
||||
|
||||
/// generate the upper aggregation plan that will execute on the frontend.
|
||||
/// Basically a logical plan resembling the following:
|
||||
/// Projection:
|
||||
@@ -42,111 +43,32 @@ use crate::dist_plan::MergeScanLogicalPlan;
|
||||
/// of the upper aggregation plan.
|
||||
pub fn step_aggr_to_upper_aggr(
|
||||
aggr_plan: &LogicalPlan,
|
||||
) -> datafusion_common::Result<[LogicalPlan; 2]> {
|
||||
) -> datafusion_common::Result<StepTransformAction> {
|
||||
let LogicalPlan::Aggregate(input_aggr) = aggr_plan else {
|
||||
return Err(datafusion_common::DataFusionError::Plan(
|
||||
"step_aggr_to_upper_aggr only accepts Aggregate plan".to_string(),
|
||||
));
|
||||
};
|
||||
if !is_all_aggr_exprs_steppable(&input_aggr.aggr_expr) {
|
||||
return Err(datafusion_common::DataFusionError::NotImplemented(
|
||||
"Some aggregate expressions are not steppable".to_string(),
|
||||
));
|
||||
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
|
||||
"Some aggregate expressions are not steppable in [{}]",
|
||||
input_aggr
|
||||
.aggr_expr
|
||||
.iter()
|
||||
.map(|e| e.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)));
|
||||
}
|
||||
let mut upper_aggr_expr = vec![];
|
||||
for aggr_expr in &input_aggr.aggr_expr {
|
||||
let Some(aggr_func) = get_aggr_func(aggr_expr) else {
|
||||
return Err(datafusion_common::DataFusionError::NotImplemented(
|
||||
"Aggregate function not found".to_string(),
|
||||
));
|
||||
};
|
||||
let col_name = aggr_expr.qualified_name();
|
||||
let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
|
||||
let upper_func = match aggr_func.func.name() {
|
||||
"sum" | "min" | "max" | "last_value" | "first_value" => {
|
||||
// aggr_calc(aggr_merge(input_column))) as col_name
|
||||
let mut new_aggr_func = aggr_func.clone();
|
||||
new_aggr_func.args = vec![input_column.clone()];
|
||||
new_aggr_func
|
||||
}
|
||||
"count" => {
|
||||
// sum(input_column) as col_name
|
||||
let mut new_aggr_func = aggr_func.clone();
|
||||
new_aggr_func.func = sum_udaf();
|
||||
new_aggr_func.args = vec![input_column.clone()];
|
||||
new_aggr_func
|
||||
}
|
||||
UDDSKETCH_STATE_NAME | UDDSKETCH_MERGE_NAME => {
|
||||
// udd_merge(bucket_size, error_rate input_column) as col_name
|
||||
let mut new_aggr_func = aggr_func.clone();
|
||||
new_aggr_func.func = Arc::new(UddSketchState::merge_udf_impl());
|
||||
new_aggr_func.args[2] = input_column.clone();
|
||||
new_aggr_func
|
||||
}
|
||||
HLL_NAME | HLL_MERGE_NAME => {
|
||||
// hll_merge(input_column) as col_name
|
||||
let mut new_aggr_func = aggr_func.clone();
|
||||
new_aggr_func.func = Arc::new(HllState::merge_udf_impl());
|
||||
new_aggr_func.args = vec![input_column.clone()];
|
||||
new_aggr_func
|
||||
}
|
||||
_ => {
|
||||
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
|
||||
"Aggregate function {} is not supported for Step aggregation",
|
||||
aggr_func.func.name()
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// deal with nested alias case
|
||||
let mut new_aggr_expr = aggr_expr.clone();
|
||||
{
|
||||
let new_aggr_func = get_aggr_func_mut(&mut new_aggr_expr).unwrap();
|
||||
*new_aggr_func = upper_func;
|
||||
}
|
||||
let step_aggr_plan = StateMergeHelper::split_aggr_node(input_aggr.clone())?;
|
||||
|
||||
upper_aggr_expr.push(new_aggr_expr);
|
||||
}
|
||||
let mut new_aggr = input_aggr.clone();
|
||||
// use lower aggregate plan as input, this will be replace by merge scan plan later
|
||||
new_aggr.input = Arc::new(LogicalPlan::Aggregate(input_aggr.clone()));
|
||||
|
||||
new_aggr.aggr_expr = upper_aggr_expr;
|
||||
|
||||
// group by expr also need to be all ref by column to avoid duplicated computing
|
||||
let mut new_group_expr = new_aggr.group_expr.clone();
|
||||
for expr in &mut new_group_expr {
|
||||
if let Expr::Column(_) = expr {
|
||||
// already a column, no need to change
|
||||
continue;
|
||||
}
|
||||
let col_name = expr.qualified_name();
|
||||
let input_column = Expr::Column(datafusion_common::Column::new(col_name.0, col_name.1));
|
||||
*expr = input_column;
|
||||
}
|
||||
new_aggr.group_expr = new_group_expr.clone();
|
||||
|
||||
let mut new_projection_exprs = new_group_expr;
|
||||
// the upper aggr expr need to be aliased to the input aggr expr's name,
|
||||
// so that the parent plan can recognize it.
|
||||
for (lower_aggr_expr, upper_aggr_expr) in
|
||||
input_aggr.aggr_expr.iter().zip(new_aggr.aggr_expr.iter())
|
||||
{
|
||||
let lower_col_name = lower_aggr_expr.qualified_name();
|
||||
let (table, col_name) = upper_aggr_expr.qualified_name();
|
||||
let aggr_out_column = Column::new(table, col_name);
|
||||
let aliased_output_aggr_expr =
|
||||
Expr::Column(aggr_out_column).alias_qualified(lower_col_name.0, lower_col_name.1);
|
||||
new_projection_exprs.push(aliased_output_aggr_expr);
|
||||
}
|
||||
let upper_aggr_plan = LogicalPlan::Aggregate(new_aggr);
|
||||
let upper_aggr_plan = upper_aggr_plan.recompute_schema()?;
|
||||
// create a projection on top of the new aggregate plan
|
||||
let new_projection =
|
||||
Projection::try_new(new_projection_exprs, Arc::new(upper_aggr_plan.clone()))?;
|
||||
let projection = LogicalPlan::Projection(new_projection);
|
||||
// return the new logical plan
|
||||
Ok([projection, upper_aggr_plan])
|
||||
// TODO(discord9): remove duplication
|
||||
let ret = StepTransformAction {
|
||||
extra_parent_plans: vec![step_aggr_plan.upper_merge.clone()],
|
||||
new_child_plan: Some(step_aggr_plan.lower_state.clone()),
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
/// Check if the given aggregate expression is steppable.
|
||||
@@ -154,25 +76,15 @@ pub fn step_aggr_to_upper_aggr(
|
||||
/// i.e. on datanode first call `state(input)` then
|
||||
/// on frontend call `calc(merge(state))` to get the final result.
|
||||
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
|
||||
let step_action = HashSet::from([
|
||||
"sum",
|
||||
"count",
|
||||
"min",
|
||||
"max",
|
||||
"first_value",
|
||||
"last_value",
|
||||
UDDSKETCH_STATE_NAME,
|
||||
UDDSKETCH_MERGE_NAME,
|
||||
HLL_NAME,
|
||||
HLL_MERGE_NAME,
|
||||
]);
|
||||
aggr_exprs.iter().all(|expr| {
|
||||
if let Some(aggr_func) = get_aggr_func(expr) {
|
||||
if aggr_func.distinct {
|
||||
// Distinct aggregate functions are not steppable (yet).
|
||||
return false;
|
||||
}
|
||||
step_action.contains(aggr_func.func.name())
|
||||
|
||||
// whether the corresponding state function exists in the registry
|
||||
FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -191,18 +103,6 @@ pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFun
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_aggr_func_mut(expr: &mut Expr) -> Option<&mut datafusion_expr::expr::AggregateFunction> {
|
||||
let mut expr_ref = expr;
|
||||
while let Expr::Alias(alias) = expr_ref {
|
||||
expr_ref = &mut alias.expr;
|
||||
}
|
||||
if let Expr::AggregateFunction(aggr_func) = expr_ref {
|
||||
Some(aggr_func)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub enum Commutativity {
|
||||
Commutative,
|
||||
@@ -247,8 +147,8 @@ impl Categorizer {
|
||||
debug!("Before Step optimize: {plan}");
|
||||
let ret = step_aggr_to_upper_aggr(plan);
|
||||
ret.ok().map(|s| TransformerAction {
|
||||
extra_parent_plans: s.to_vec(),
|
||||
new_child_plan: None,
|
||||
extra_parent_plans: s.extra_parent_plans,
|
||||
new_child_plan: s.new_child_plan,
|
||||
})
|
||||
})),
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user