Files
greptimedb/src/common/function/src/aggrs/aggr_wrapper.rs
Ruihang Xia f034255fe6 perf: support group accumulators for state wrapper (#7826)
* perf: support group accumulators for state wrapper

* new tests and avoid clone

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2026-03-19 22:40:52 +00:00

880 lines
31 KiB
Rust

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Wrapper for making aggregate functions out of state/merge functions of original aggregate functions.
//!
//! For example, for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`
//! (registered under the names `__foo_state` and `__foo_merge`).
//!
//! `foo_state`'s input args is the same as `foo`'s, and its output is a state object.
//! Note that `foo_state` might have multiple output columns, so its output is a struct array
//! in which each output column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
//!
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use arrow::array::{ArrayRef, BooleanArray, StructArray};
use arrow_schema::{FieldRef, Fields};
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_aggregate::count::Count;
use datafusion::functions_aggregate::min_max::{Max, Min};
use datafusion::optimizer::AnalyzerRule;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable,
GroupsAccumulator, LogicalPlan, Signature,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};
use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};
pub mod fix_order;
#[cfg(test)]
mod tests;
/// Returns the name of the state function for the given aggregate function name.
///
/// The state function is used to compute the (intermediate) state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`, e.g. `sum` maps to
/// `__sum_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    format!("__{aggr_name}_state")
}
/// Returns the name of the merge function for the given aggregate function name.
///
/// The merge function is used to merge the states produced by the state functions.
/// The merge function's name is in the format `__<aggr_name>_merge`, e.g. `sum` maps to
/// `__sum_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    format!("__{aggr_name}_merge")
}
/// Returns `true` when every expression in `aggr_exprs` can be evaluated in steps,
/// i.e. on datanode first call `state(input)` and then on frontend call
/// `calc(merge(state))` to get the final result.
///
/// An expression qualifies when it is an (optionally aliased) aggregate function call that
/// is not `DISTINCT` and whose `__<name>_state` function exists in the global registry.
/// An empty slice is trivially steppable.
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
    aggr_exprs.iter().all(|expr| {
        let Some(aggr_func) = get_aggr_func(expr) else {
            return false;
        };
        // Distinct aggregate functions are not steppable(yet).
        // TODO(discord9): support distinct aggregate functions.
        !aggr_func.params.distinct
            && FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
    })
}
/// Returns the aggregate function call inside `expr`, looking through any number of
/// `Alias` wrappers; returns `None` if the unwrapped expression is not an aggregate call.
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
    let mut current = expr;
    loop {
        match current {
            Expr::Alias(alias) => current = &alias.expr,
            Expr::AggregateFunction(aggr_func) => return Some(aggr_func),
            _ => return None,
        }
    }
}
/// Helper for making aggregate functions out of the state and merge functions of the original
/// aggregate functions.
///
/// This is a unit struct with no fields; its functionality lives in associated functions
/// (registering the `state` functions, and splitting an `Aggregate` plan into a lower state
/// plan and an upper merge plan).
///
/// Notice state functions may have multiple output columns, so their return type is always a
/// struct array, and the merge function is used to merge those struct-encoded states.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
/// A struct to hold the two aggregate plans produced by splitting one aggregate plan:
/// one for the state function (lower) and one for the merge function (upper).
// NOTE(review): the reason for `#[allow(unused)]` is not visible here — presumably some
// field is not read in every build; confirm whether it is still needed.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper merge plan: an aggregate plan whose input is `lower_state`, and which merges the
    /// states computed by the state functions.
    pub upper_merge: LogicalPlan,
    /// Lower state plan: the aggregate plan that computes the state of each aggregate function.
    pub lower_state: LogicalPlan,
}
impl StateMergeHelper {
/// Register all the `state` function of supported aggregate functions.
/// Note that can't register `merge` function here, as it needs to be created from the original aggregate function with given input types.
pub fn register(registry: &FunctionRegistry) {
let all_default = all_default_aggregate_functions();
let greptime_custom_aggr_functions = registry.aggregate_functions();
// if our custom aggregate function have the same name as the default aggregate function, we will override it.
let supported = all_default
.into_iter()
.chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
.collect::<Vec<_>>();
debug!(
"Registering state functions for supported: {:?}",
supported.iter().map(|f| f.name()).collect::<Vec<_>>()
);
let state_func = supported.into_iter().filter_map(|f| {
StateWrapper::new((*f).clone())
.inspect_err(
|e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
)
.ok()
.map(AggregateUDF::new_from_impl)
});
for func in state_func {
registry.register_aggr(func);
}
}
/// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
///
pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
let aggr = {
// certain aggr func need type coercion to work correctly, so we need to analyze the plan first.
let aggr_plan = TypeCoercion::new().analyze(
LogicalPlan::Aggregate(aggr_plan).clone(),
&Default::default(),
)?;
if let LogicalPlan::Aggregate(aggr) = aggr_plan {
aggr
} else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
aggr_plan
)));
}
};
let mut lower_aggr_exprs = vec![];
let mut upper_aggr_exprs = vec![];
// group exprs for upper plan should refer to the output group expr as column from lower plan
// to avoid re-compute group exprs again.
let upper_group_exprs = aggr
.group_expr
.iter()
.map(|c| c.qualified_name())
.map(|(r, c)| Expr::Column(Column::new(r, c)))
.collect();
for aggr_expr in aggr.aggr_expr.iter() {
let Some(aggr_func) = get_aggr_func(aggr_expr) else {
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
"Unsupported aggregate expression for step aggr optimize: {:?}",
aggr_expr
)));
};
let original_input_fields = aggr_func
.params
.args
.iter()
.map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
.collect::<Result<Vec<_>, _>>()?;
// first create the state function from the original aggregate function.
let state_func = StateWrapper::new((*aggr_func.func).clone())?;
let expr = AggregateFunction {
func: Arc::new(state_func.into()),
params: aggr_func.params.clone(),
};
let expr = Expr::AggregateFunction(expr);
let lower_state_output_col_name = expr.schema_name().to_string();
lower_aggr_exprs.push(expr);
// then create the merge function using the physical expression of the original aggregate function
let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
aggr_expr,
aggr.input.schema(),
aggr.input.schema().as_arrow(),
&Default::default(),
)?;
let merge_func = MergeWrapper::new(
(*aggr_func.func).clone(),
original_phy_expr,
original_input_fields,
)?;
let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
let expr = AggregateFunction {
func: Arc::new(merge_func.into()),
// notice filter/order_by is not supported in the merge function, as it's not meaningful to have them in the merge phase.
// do notice this order by is only removed in the outer logical plan, the physical plan still have order by and hence
// can create correct accumulator with order by.
params: AggregateFunctionParams {
args: vec![arg],
distinct: aggr_func.params.distinct,
filter: None,
order_by: vec![],
null_treatment: aggr_func.params.null_treatment,
},
};
// alias to the original aggregate expr's schema name, so parent plan can refer to it
// correctly.
let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
upper_aggr_exprs.push(expr);
}
let mut lower = aggr.clone();
lower.aggr_expr = lower_aggr_exprs;
let lower_plan = LogicalPlan::Aggregate(lower);
// update aggregate's output schema
let lower_plan = lower_plan.recompute_schema()?;
// should only affect two udaf `first_value/last_value`
// which only them have meaningful order by field
let fixed_lower_plan =
FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;
let upper = Aggregate::try_new(
Arc::new(fixed_lower_plan.clone()),
upper_group_exprs,
upper_aggr_exprs.clone(),
)?;
let aggr_plan = LogicalPlan::Aggregate(aggr);
// upper schema's output schema should be the same as the original aggregate plan's output schema
let upper_check = upper;
let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
if *upper_plan.schema() != *aggr_plan.schema() {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
upper_plan.schema(),
aggr_plan.schema()
)));
}
Ok(StepAggrPlan {
lower_state: fixed_lower_plan,
upper_merge: upper_plan,
})
}
}
/// Wrapper to make an aggregate function out of a state function.
///
/// The wrapped function (named `__<inner_name>_state`) uses the same signature as `inner`,
/// and returns `inner`'s accumulator state packed into a single struct value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StateWrapper {
    /// The original aggregate function whose state is exposed.
    inner: AggregateUDF,
    /// Name of the state function, built from `inner`'s name.
    name: String,
    /// Ordering fields handed to `inner`'s `state_fields` when computing the return type.
    /// Default to empty, might get fixed by analyzer later
    ordering: Vec<FieldRef>,
    /// Distinct flag handed to `inner`'s `state_fields` when computing the return type.
    /// Default to false, might get fixed by analyzer later
    distinct: bool,
}
impl StateWrapper {
    /// Create a state function wrapping `inner`, named `__<inner_name>_state`.
    ///
    /// The ordering fields start empty and `distinct` starts as `false`; both might get fixed
    /// by an analyzer later.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept so callers don't need to change if
    /// validation is added later.
    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
        let name = aggr_state_func_name(inner.name());
        Ok(Self {
            inner,
            name,
            ordering: vec![],
            distinct: false,
        })
    }

    /// The original aggregate function wrapped by this state function.
    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }

    /// Deduce the return field of the original aggregate function
    /// based on the accumulator arguments.
    ///
    /// Each input expression's field is resolved against `acc_args.schema`, and the result is
    /// passed to the inner function's `return_field`. On failure, the wrapper, the arguments
    /// and the error are logged before the error is returned.
    pub fn deduce_aggr_return_type(
        &self,
        acc_args: &datafusion_expr::function::AccumulatorArgs,
    ) -> datafusion_common::Result<FieldRef> {
        let input_fields = acc_args
            .exprs
            .iter()
            .map(|e| e.return_field(acc_args.schema))
            .collect::<Result<Vec<_>, _>>()?;
        self.inner.return_field(&input_fields).inspect_err(|e| {
            common_telemetry::error!(
                "StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
                &self,
                &acc_args,
                e
            );
        })
    }

    /// Replace `acc_args.return_field` (the wrapper's struct state type) with the original
    /// aggregate function's return field, so the arguments can be handed to `inner`.
    fn fix_inner_acc_args<'b>(
        &self,
        mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<datafusion_expr::function::AccumulatorArgs<'b>> {
        acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
        Ok(acc_args)
    }
}
impl AggregateUDFImpl for StateWrapper {
    /// Create a row-based accumulator. The inner function's accumulator is built with the
    /// inner return field, then wrapped in a `StateAccum` that emits its state as a struct.
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        // `acc_args.return_type()` is the wrapper's struct state type; capture it before the
        // args are fixed up for the original aggregate function.
        let state_type = acc_args.return_type().clone();
        let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?;
        Ok(Box::new(StateAccum::new(inner, state_type)?))
    }

    /// Group accumulators are supported when the inner function supports them for the
    /// fixed-up arguments. If the arguments can't be fixed up, this reports "not supported".
    fn groups_accumulator_supported(
        &self,
        acc_args: datafusion_expr::function::AccumulatorArgs,
    ) -> bool {
        self.fix_inner_acc_args(acc_args)
            .map(|args| self.inner.inner().groups_accumulator_supported(args))
            .unwrap_or(false)
    }

    /// Create the inner function's groups accumulator, then wrap it in a `StateGroupsAccum`
    /// that emits per-group state as a struct array.
    fn create_groups_accumulator(
        &self,
        acc_args: datafusion_expr::function::AccumulatorArgs,
    ) -> datafusion_common::Result<Box<dyn GroupsAccumulator>> {
        let state_type = acc_args.return_type().clone();
        let inner = self
            .inner
            .inner()
            .create_groups_accumulator(self.fix_inner_acc_args(acc_args)?)?;
        Ok(Box::new(StateGroupsAccum::new(inner, state_type)?))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Return the inner function's state fields as the output struct type.
    ///
    /// Every state field is made nullable, because the state can be null when there are no
    /// input rows.
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let input_fields = &arg_types
            .iter()
            .map(|x| Arc::new(Field::new("x", x.clone(), false)))
            .collect::<Vec<_>>();

        let state_fields_args = StateFieldsArgs {
            name: self.inner().name(),
            input_fields,
            return_field: self.inner.return_field(input_fields)?,
            // those args are also needed as they are vital to construct the state fields correctly.
            ordering_fields: &self.ordering,
            is_distinct: self.distinct,
        };
        let state_fields = self.inner.state_fields(state_fields_args)?;

        let state_fields = state_fields
            .into_iter()
            .map(|f| {
                let mut f = f.as_ref().clone();
                // since state can be null when no input rows, so make all fields nullable
                f.set_nullable(true);
                Arc::new(f)
            })
            .collect::<Vec<_>>();

        let struct_field = DataType::Struct(state_fields.into());
        Ok(struct_field)
    }

    /// The state function's state fields are the same as the original aggregate function's
    /// state fields (computed with the inner function's own return field).
    fn state_fields(
        &self,
        args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<FieldRef>> {
        let state_fields_args = StateFieldsArgs {
            name: args.name,
            input_fields: args.input_fields,
            return_field: self.inner.return_field(args.input_fields)?,
            ordering_fields: args.ordering_fields,
            is_distinct: args.is_distinct,
        };
        self.inner.state_fields(state_fields_args)
    }

    /// The state function's signature is the same as the original aggregate function's signature.
    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Delegates to the original aggregate function's coercion, since the state function
    /// accepts the same arguments as the original function.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }

    /// Compute the state directly from statistics, for `count`/`min`/`max` only.
    ///
    /// For these functions the single state column holds the same value as the final result,
    /// so the inner function's statistics-derived value is wrapped into a one-field struct.
    ///
    /// Returns `None` (falling back to normal execution) for any other function, when the
    /// return type is not a struct with exactly one field, when the inner function can't
    /// answer from statistics, or when the inner value doesn't fit the state field.
    fn value_from_stats(
        &self,
        statistics_args: &datafusion_expr::StatisticsArgs,
    ) -> Option<ScalarValue> {
        let inner = self.inner().inner().as_any();
        // only count/min/max need special handling here, for getting result from statistics
        // the result of count/min/max is also the result of count_state so can return directly
        let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
        if !can_use_stat {
            return None;
        }

        // A single scalar can only be wrapped into a single-field state struct.
        let DataType::Struct(fields) = statistics_args.return_type else {
            return None;
        };
        let [state_field] = &fields[..] else {
            return None;
        };

        // fix return type by using the state field's data type instead of the struct type
        let state_type = state_field.data_type().clone();
        let fixed_args = datafusion_expr::StatisticsArgs {
            statistics: statistics_args.statistics,
            return_type: &state_type,
            is_distinct: statistics_args.is_distinct,
            exprs: statistics_args.exprs,
        };
        let ret = self.inner().value_from_stats(&fixed_args)?;

        // wrap the result into struct scalar value; `try_new` (not `new`) so that a value whose
        // type doesn't match the state field falls back to normal execution instead of panicking.
        let array = ret.to_array().ok()?;
        let struct_array = StructArray::try_new(fields.clone(), vec![array], None).ok()?;
        Some(ScalarValue::Struct(Arc::new(struct_array)))
    }
}
/// The wrapper's input is the same as the original aggregate function's input,
/// and the output is the state function's output.
///
/// On `evaluate`, the inner accumulator's state is returned as a single struct scalar.
#[derive(Debug)]
pub struct StateAccum {
    /// Accumulator of the original aggregate function.
    inner: Box<dyn Accumulator>,
    /// Expected fields of the emitted state struct.
    state_fields: Fields,
}
/// Group-based counterpart of `StateAccum`: wraps the original aggregate function's groups
/// accumulator and, on `evaluate`, emits the per-group state as one struct array.
pub struct StateGroupsAccum {
    /// Groups accumulator of the original aggregate function.
    inner: Box<dyn GroupsAccumulator>,
    /// Expected fields of the emitted state struct.
    state_fields: Fields,
}
impl StateGroupsAccum {
    /// Wrap `inner`, expecting `state_type` to be a struct whose fields describe the state
    /// columns. Any non-struct type is rejected with an internal error.
    fn new(
        inner: Box<dyn GroupsAccumulator>,
        state_type: DataType,
    ) -> datafusion_common::Result<Self> {
        match state_type {
            DataType::Struct(state_fields) => Ok(Self {
                inner,
                state_fields,
            }),
            other => Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for state, got: {:?}",
                other
            ))),
        }
    }

    /// Pack per-group state columns into a single struct array.
    ///
    /// If the column types line up with the expected state fields, those fields form the
    /// struct schema. Otherwise the mismatch is logged at debug level and a placeholder schema
    /// (`col_<i>[mismatch_state]`, all nullable) is derived from the columns themselves.
    fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
        let types_match = arrays
            .iter()
            .map(|array| array.data_type())
            .eq(self.state_fields.iter().map(|field| field.data_type()));

        let schema = if types_match {
            self.state_fields.clone()
        } else {
            let given_types: Vec<DataType> = arrays
                .iter()
                .map(|array| array.data_type().clone())
                .collect();
            debug!(
                "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
                self.state_fields.len(),
                arrays.len(),
                self.state_fields,
                given_types,
            );
            arrays
                .iter()
                .enumerate()
                .map(|(index, array)| {
                    Field::new(
                        format!("col_{index}[mismatch_state]").as_str(),
                        array.data_type().clone(),
                        true,
                    )
                })
                .collect::<Fields>()
        };

        let packed = StructArray::try_new(schema, arrays, None)?;
        Ok(Arc::new(packed))
    }
}
impl GroupsAccumulator for StateGroupsAccum {
    /// Memory footprint is the inner accumulator's.
    fn size(&self) -> usize {
        self.inner.size()
    }

    /// Feed raw input rows to the inner groups accumulator.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        let inner = &mut self.inner;
        inner.update_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Merge partial states into the inner groups accumulator.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        let inner = &mut self.inner;
        inner.merge_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// The "result" of a state function is the inner state, packed as one struct column.
    fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
        let state_columns = self.inner.state(emit_to)?;
        self.wrap_state_arrays(state_columns)
    }

    /// The intermediate state is exactly the inner accumulator's state.
    fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
        let inner = &mut self.inner;
        inner.state(emit_to)
    }

    /// Whether raw input can be turned into state directly is decided by the inner accumulator.
    fn supports_convert_to_state(&self) -> bool {
        self.inner.supports_convert_to_state()
    }

    /// Convert raw input into state columns using the inner accumulator.
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> datafusion_common::Result<Vec<ArrayRef>> {
        let inner = &self.inner;
        inner.convert_to_state(values, opt_filter)
    }
}
impl StateAccum {
    /// Wrap `inner`, expecting `state_type` to be a struct whose fields describe the state
    /// columns.
    ///
    /// # Errors
    ///
    /// Returns an internal error if `state_type` is not a struct type.
    pub fn new(
        inner: Box<dyn Accumulator>,
        state_type: DataType,
    ) -> datafusion_common::Result<Self> {
        match state_type {
            DataType::Struct(state_fields) => Ok(Self {
                inner,
                state_fields,
            }),
            other => Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for state, got: {:?}",
                other
            ))),
        }
    }
}
impl Accumulator for StateAccum {
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
let state = self.inner.state()?;
let array = state
.iter()
.map(|s| s.to_array())
.collect::<Result<Vec<_>, _>>()?;
let array_type = array
.iter()
.map(|a| a.data_type().clone())
.collect::<Vec<_>>();
let expected_type: Vec<_> = self
.state_fields
.iter()
.map(|f| f.data_type().clone())
.collect();
if array_type != expected_type {
debug!(
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
self.state_fields.len(),
array.len(),
self.state_fields,
array_type,
);
let guess_schema = array
.iter()
.enumerate()
.map(|(index, array)| {
Field::new(
format!("col_{index}[mismatch_state]").as_str(),
array.data_type().clone(),
true,
)
})
.collect::<Fields>();
let arr = StructArray::try_new(guess_schema, array, None)?;
return Ok(ScalarValue::Struct(Arc::new(arr)));
}
let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
Ok(ScalarValue::Struct(Arc::new(struct_array)))
}
fn merge_batch(
&mut self,
states: &[datatypes::arrow::array::ArrayRef],
) -> datafusion_common::Result<()> {
self.inner.merge_batch(states)
}
fn update_batch(
&mut self,
values: &[datatypes::arrow::array::ArrayRef],
) -> datafusion_common::Result<()> {
self.inner.update_batch(values)
}
fn size(&self) -> usize {
self.inner.size()
}
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
self.inner.state()
}
}
/// TODO(discord9): mark this function as non-ser/de able
///
/// Aggregate function that takes the struct-encoded state produced by the corresponding state
/// function and outputs the original aggregate function's result.
///
/// This wrapper shouldn't be registered as a udaf, as it contains extra data that is not
/// serializable and that changes for different logical plans.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    /// The original aggregate function.
    inner: AggregateUDF,
    /// Name of the merge function, built from `inner`'s name.
    name: String,
    /// User-defined signature; the single struct argument is checked in `coerce_types`.
    merge_signature: Signature,
    /// The original physical expression of the aggregate function, can't store the original aggregate function directly, as PhysicalExpr didn't implement Any
    original_phy_expr: Arc<AggregateFunctionExpr>,
    /// Fixed return field: the original aggregate function's return field for the original
    /// input fields.
    return_field: FieldRef,
}
impl MergeWrapper {
    /// Create the merge function for `inner`, named `__<inner_name>_merge`.
    ///
    /// - `original_phy_expr`: physical expression of the original aggregate call. It is used
    ///   to create the accumulator that merges states and to report the state fields.
    /// - `original_input_fields`: fields of the original call's arguments. They are used to
    ///   compute the fixed return field of the merge function.
    ///
    /// # Errors
    ///
    /// Returns any error raised by `inner` while computing its return field.
    pub fn new(
        inner: AggregateUDF,
        original_phy_expr: Arc<AggregateFunctionExpr>,
        original_input_fields: Vec<FieldRef>,
    ) -> datafusion_common::Result<Self> {
        let name = aggr_merge_func_name(inner.name());
        // the input type is actually struct type, which is the state fields of the original aggregate function.
        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
        // `return_field` already yields an owned `FieldRef`, so no extra clone is needed.
        let return_field = inner.return_field(&original_input_fields)?;
        Ok(Self {
            inner,
            name,
            merge_signature,
            original_phy_expr,
            return_field,
        })
    }

    /// The original aggregate function wrapped by this merge function.
    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }
}
impl AggregateUDFImpl for MergeWrapper {
    /// Create the merge accumulator.
    ///
    /// The merge function takes exactly one argument: the struct column produced by the state
    /// function. A fresh accumulator of the original aggregate function is created from the
    /// stored physical expression and wrapped in a `MergeAccum`.
    ///
    /// # Errors
    ///
    /// Returns an internal error when there isn't exactly one argument or when its type is not
    /// a struct. Propagates errors from resolving the argument type or creating the accumulator.
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        let not_one_struct = || {
            datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                acc_args.schema
            ))
        };

        // Resolve the single argument's type once; both wrong arity and non-struct input are
        // reported with the same error.
        let input_type = match acc_args.exprs {
            [expr] => expr.data_type(acc_args.schema)?,
            _ => return Err(not_one_struct()),
        };
        let DataType::Struct(fields) = input_type else {
            return Err(not_one_struct());
        };

        let inner_accum = self.original_phy_expr.create_accumulator()?;
        Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Notice here the `arg_types` is actually the `state_fields`'s data types,
    /// so return fixed return type instead of using `arg_types` to determine the return type.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        // The return type is the same as the original aggregate function's return type.
        Ok(self.return_field.data_type().clone())
    }

    /// Similar to return_type, we just return the fixed return field.
    fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
        Ok(self.return_field.clone())
    }

    fn signature(&self) -> &Signature {
        &self.merge_signature
    }

    /// No actual coercion: only checks that there is exactly one argument and that it is a
    /// struct, then returns the argument types unchanged.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        // just check if the arg_types are only one and is struct array
        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                arg_types
            )));
        }
        Ok(arg_types.to_vec())
    }

    /// Just return the original aggregate function's state fields.
    fn state_fields(
        &self,
        _args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<FieldRef>> {
        self.original_phy_expr.state_fields()
    }
}
/// Merge wrappers compare equal when they wrap equal original aggregate functions; the
/// physical expression and the return field are not compared.
impl PartialEq for MergeWrapper {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.inner, &other.inner);
        lhs == rhs
    }
}
// Equality compares only the wrapped `inner` function; `Eq` assumes `AggregateUDF`'s
// equality is a full equivalence relation.
impl Eq for MergeWrapper {}
/// Hashes only the wrapped `inner` function, matching the fields compared by `PartialEq`.
impl Hash for MergeWrapper {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.inner, state)
    }
}
/// The merge accumulator. It modifies `update_batch`'s behavior to accept one struct array that
/// holds the state fields of the original aggregate function, and merges those states into the
/// original accumulator.
///
/// The output is the same as the original aggregate function's output.
#[derive(Debug)]
pub struct MergeAccum {
    /// Accumulator of the original aggregate function; incoming states are merged into it.
    inner: Box<dyn Accumulator>,
    /// Expected fields of the incoming state struct; a mismatch is only logged.
    state_fields: Fields,
}
impl MergeAccum {
    /// Wrap `inner`, remembering `state_fields` (the fields the incoming state struct is
    /// expected to have) for a debug-level sanity check.
    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
        let state_fields = state_fields.clone();
        Self {
            inner,
            state_fields,
        }
    }
}
impl Accumulator for MergeAccum {
    /// Treat the input as states from other accumulators and merge them.
    ///
    /// Only `values[0]` is used; it must be a struct array whose child columns are the state
    /// columns of the original accumulator. The columns are forwarded positionally. A mismatch
    /// against the expected state fields is only logged at debug level.
    ///
    /// # Errors
    ///
    /// Internal error if `values` is empty or `values[0]` is not a struct array; otherwise
    /// whatever the inner `merge_batch` returns.
    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
        let Some(value) = values.first() else {
            return Err(datafusion_common::DataFusionError::Internal(
                "No values provided for merge".to_string(),
            ));
        };

        let Some(struct_arr) = value.as_any().downcast_ref::<StructArray>() else {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected StructArray, got: {:?}",
                value.data_type()
            )));
        };

        let fields = struct_arr.fields();
        if fields != &self.state_fields {
            // A field mismatch (e.g. different names) might still be acceptable for the inner
            // accumulator, so only log it and continue.
            debug!(
                "State fields mismatch, expected: {:?}, got: {:?}",
                self.state_fields, fields
            );
        }

        self.inner.merge_batch(struct_arr.columns())
    }

    /// Partial states of the original accumulator go straight to the inner accumulator.
    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
        let inner = &mut self.inner;
        inner.merge_batch(states)
    }

    /// The final result is the original aggregate function's result.
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
        let inner = &mut self.inner;
        inner.evaluate()
    }

    /// The intermediate state is the original accumulator's state.
    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
        let inner = &mut self.inner;
        inner.state()
    }

    /// Memory footprint is the inner accumulator's.
    fn size(&self) -> usize {
        self.inner.size()
    }
}