mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-15 20:40:39 +00:00
* perf: support group accumulators for state wrapper * new tests and avoid clone Signed-off-by: Ruihang Xia <waynestxia@gmail.com> --------- Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
880 lines
31 KiB
Rust
880 lines
31 KiB
Rust
// Copyright 2023 Greptime Team
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
//! Wrapper for making aggregate functions out of state/merge functions of original aggregate functions.
|
|
//!
|
|
//! i.e. for a aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`.
|
|
//!
|
|
//! `foo_state`'s input args is the same as `foo`'s, and its output is a state object.
|
|
//! Note that `foo_state` might have multiple output columns, so it's a struct array
|
|
//! that each output column is a struct field.
|
|
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s input.
|
|
//!
|
|
|
|
use std::hash::{Hash, Hasher};
|
|
use std::sync::Arc;
|
|
|
|
use arrow::array::{ArrayRef, BooleanArray, StructArray};
|
|
use arrow_schema::{FieldRef, Fields};
|
|
use common_telemetry::debug;
|
|
use datafusion::functions_aggregate::all_default_aggregate_functions;
|
|
use datafusion::functions_aggregate::count::Count;
|
|
use datafusion::functions_aggregate::min_max::{Max, Min};
|
|
use datafusion::optimizer::AnalyzerRule;
|
|
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
|
|
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
|
|
use datafusion_common::{Column, ScalarValue};
|
|
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
|
|
use datafusion_expr::function::StateFieldsArgs;
|
|
use datafusion_expr::{
|
|
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable,
|
|
GroupsAccumulator, LogicalPlan, Signature,
|
|
};
|
|
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
|
|
use datatypes::arrow::datatypes::{DataType, Field};
|
|
|
|
use crate::aggrs::aggr_wrapper::fix_order::FixStateUdafOrderingAnalyzer;
|
|
use crate::function_registry::{FUNCTION_REGISTRY, FunctionRegistry};
|
|
|
|
pub mod fix_order;
|
|
#[cfg(test)]
|
|
mod tests;
|
|
|
|
/// Returns the name of the state function for the given aggregate function name.
///
/// The state function computes the intermediate state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    format!("__{}_state", aggr_name)
}
|
|
|
|
/// Returns the name of the merge function for the given aggregate function name.
///
/// The merge function is used to merge the states produced by the state functions.
/// The merge function's name is in the format `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    format!("__{}_merge", aggr_name)
}
|
|
|
|
/// Check if the given aggregate expression is steppable.
|
|
/// As in if it can be split into multiple steps:
|
|
/// i.e. on datanode first call `state(input)` then
|
|
/// on frontend call `calc(merge(state))` to get the final result.
|
|
pub fn is_all_aggr_exprs_steppable(aggr_exprs: &[Expr]) -> bool {
|
|
aggr_exprs.iter().all(|expr| {
|
|
if let Some(aggr_func) = get_aggr_func(expr) {
|
|
if aggr_func.params.distinct {
|
|
// Distinct aggregate functions are not steppable(yet).
|
|
// TODO(discord9): support distinct aggregate functions.
|
|
return false;
|
|
}
|
|
|
|
// whether the corresponding state function exists in the registry
|
|
FUNCTION_REGISTRY.is_aggr_func_exist(&aggr_state_func_name(aggr_func.func.name()))
|
|
} else {
|
|
false
|
|
}
|
|
})
|
|
}
|
|
|
|
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
|
|
let mut expr_ref = expr;
|
|
while let Expr::Alias(alias) = expr_ref {
|
|
expr_ref = &alias.expr;
|
|
}
|
|
if let Expr::AggregateFunction(aggr_func) = expr_ref {
|
|
Some(aggr_func)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// A wrapper to make an aggregate function out of the state and merge functions of the original aggregate function.
/// It contains the original aggregate function, the state functions, and the merge function.
///
/// Notice state functions may have multiple output columns, so its return type is always a struct array, and the merge function is used to merge the states of the state functions.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
|
|
|
|
/// A struct to hold the two aggregate plans, one for the state function(lower) and one for the merge function(upper).
///
/// Produced by `StateMergeHelper::split_aggr_node`; the lower plan computes partial
/// states (e.g. on datanodes) and the upper plan merges them into the final result.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper merge plan, which is the aggregate plan that merges the states of the state function.
    pub upper_merge: LogicalPlan,
    /// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
    pub lower_state: LogicalPlan,
}
|
|
|
|
impl StateMergeHelper {
|
|
/// Register all the `state` function of supported aggregate functions.
|
|
/// Note that can't register `merge` function here, as it needs to be created from the original aggregate function with given input types.
|
|
pub fn register(registry: &FunctionRegistry) {
|
|
let all_default = all_default_aggregate_functions();
|
|
let greptime_custom_aggr_functions = registry.aggregate_functions();
|
|
|
|
// if our custom aggregate function have the same name as the default aggregate function, we will override it.
|
|
let supported = all_default
|
|
.into_iter()
|
|
.chain(greptime_custom_aggr_functions.into_iter().map(Arc::new))
|
|
.collect::<Vec<_>>();
|
|
debug!(
|
|
"Registering state functions for supported: {:?}",
|
|
supported.iter().map(|f| f.name()).collect::<Vec<_>>()
|
|
);
|
|
|
|
let state_func = supported.into_iter().filter_map(|f| {
|
|
StateWrapper::new((*f).clone())
|
|
.inspect_err(
|
|
|e| common_telemetry::error!(e; "Failed to register state function for {:?}", f),
|
|
)
|
|
.ok()
|
|
.map(AggregateUDF::new_from_impl)
|
|
});
|
|
|
|
for func in state_func {
|
|
registry.register_aggr(func);
|
|
}
|
|
}
|
|
|
|
/// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
|
|
///
|
|
pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
|
|
let aggr = {
|
|
// certain aggr func need type coercion to work correctly, so we need to analyze the plan first.
|
|
let aggr_plan = TypeCoercion::new().analyze(
|
|
LogicalPlan::Aggregate(aggr_plan).clone(),
|
|
&Default::default(),
|
|
)?;
|
|
if let LogicalPlan::Aggregate(aggr) = aggr_plan {
|
|
aggr
|
|
} else {
|
|
return Err(datafusion_common::DataFusionError::Internal(format!(
|
|
"Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
|
|
aggr_plan
|
|
)));
|
|
}
|
|
};
|
|
let mut lower_aggr_exprs = vec![];
|
|
let mut upper_aggr_exprs = vec![];
|
|
|
|
// group exprs for upper plan should refer to the output group expr as column from lower plan
|
|
// to avoid re-compute group exprs again.
|
|
let upper_group_exprs = aggr
|
|
.group_expr
|
|
.iter()
|
|
.map(|c| c.qualified_name())
|
|
.map(|(r, c)| Expr::Column(Column::new(r, c)))
|
|
.collect();
|
|
|
|
for aggr_expr in aggr.aggr_expr.iter() {
|
|
let Some(aggr_func) = get_aggr_func(aggr_expr) else {
|
|
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
|
|
"Unsupported aggregate expression for step aggr optimize: {:?}",
|
|
aggr_expr
|
|
)));
|
|
};
|
|
|
|
let original_input_fields = aggr_func
|
|
.params
|
|
.args
|
|
.iter()
|
|
.map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
|
|
// first create the state function from the original aggregate function.
|
|
let state_func = StateWrapper::new((*aggr_func.func).clone())?;
|
|
|
|
let expr = AggregateFunction {
|
|
func: Arc::new(state_func.into()),
|
|
params: aggr_func.params.clone(),
|
|
};
|
|
let expr = Expr::AggregateFunction(expr);
|
|
let lower_state_output_col_name = expr.schema_name().to_string();
|
|
|
|
lower_aggr_exprs.push(expr);
|
|
|
|
// then create the merge function using the physical expression of the original aggregate function
|
|
let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
|
|
aggr_expr,
|
|
aggr.input.schema(),
|
|
aggr.input.schema().as_arrow(),
|
|
&Default::default(),
|
|
)?;
|
|
|
|
let merge_func = MergeWrapper::new(
|
|
(*aggr_func.func).clone(),
|
|
original_phy_expr,
|
|
original_input_fields,
|
|
)?;
|
|
let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
|
|
let expr = AggregateFunction {
|
|
func: Arc::new(merge_func.into()),
|
|
// notice filter/order_by is not supported in the merge function, as it's not meaningful to have them in the merge phase.
|
|
// do notice this order by is only removed in the outer logical plan, the physical plan still have order by and hence
|
|
// can create correct accumulator with order by.
|
|
params: AggregateFunctionParams {
|
|
args: vec![arg],
|
|
distinct: aggr_func.params.distinct,
|
|
filter: None,
|
|
order_by: vec![],
|
|
null_treatment: aggr_func.params.null_treatment,
|
|
},
|
|
};
|
|
|
|
// alias to the original aggregate expr's schema name, so parent plan can refer to it
|
|
// correctly.
|
|
let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
|
|
upper_aggr_exprs.push(expr);
|
|
}
|
|
|
|
let mut lower = aggr.clone();
|
|
lower.aggr_expr = lower_aggr_exprs;
|
|
let lower_plan = LogicalPlan::Aggregate(lower);
|
|
|
|
// update aggregate's output schema
|
|
let lower_plan = lower_plan.recompute_schema()?;
|
|
|
|
// should only affect two udaf `first_value/last_value`
|
|
// which only them have meaningful order by field
|
|
let fixed_lower_plan =
|
|
FixStateUdafOrderingAnalyzer.analyze(lower_plan, &Default::default())?;
|
|
|
|
let upper = Aggregate::try_new(
|
|
Arc::new(fixed_lower_plan.clone()),
|
|
upper_group_exprs,
|
|
upper_aggr_exprs.clone(),
|
|
)?;
|
|
let aggr_plan = LogicalPlan::Aggregate(aggr);
|
|
|
|
// upper schema's output schema should be the same as the original aggregate plan's output schema
|
|
let upper_check = upper;
|
|
let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
|
|
if *upper_plan.schema() != *aggr_plan.schema() {
|
|
return Err(datafusion_common::DataFusionError::Internal(format!(
|
|
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[original]:{}",
|
|
upper_plan.schema(),
|
|
aggr_plan.schema()
|
|
)));
|
|
}
|
|
|
|
Ok(StepAggrPlan {
|
|
lower_state: fixed_lower_plan,
|
|
upper_merge: upper_plan,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Wrapper to make an aggregate function out of a state function.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StateWrapper {
    /// The original aggregate function being wrapped.
    inner: AggregateUDF,
    /// Name of the state function, i.e. `__<inner_name>_state`.
    name: String,
    /// Ordering fields forwarded when computing state fields.
    /// Default to empty, might get fixed by analyzer later
    ordering: Vec<FieldRef>,
    /// Whether the wrapped call is `DISTINCT`.
    /// Default to false, might get fixed by analyzer later
    distinct: bool,
}
|
|
|
|
impl StateWrapper {
|
|
/// `state_index`: The index of the state in the output of the state function.
|
|
pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
|
|
let name = aggr_state_func_name(inner.name());
|
|
Ok(Self {
|
|
inner,
|
|
name,
|
|
ordering: vec![],
|
|
distinct: false,
|
|
})
|
|
}
|
|
|
|
pub fn inner(&self) -> &AggregateUDF {
|
|
&self.inner
|
|
}
|
|
|
|
/// Deduce the return type of the original aggregate function
|
|
/// based on the accumulator arguments.
|
|
///
|
|
pub fn deduce_aggr_return_type(
|
|
&self,
|
|
acc_args: &datafusion_expr::function::AccumulatorArgs,
|
|
) -> datafusion_common::Result<FieldRef> {
|
|
let input_fields = acc_args
|
|
.exprs
|
|
.iter()
|
|
.map(|e| e.return_field(acc_args.schema))
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
self.inner.return_field(&input_fields).inspect_err(|e| {
|
|
common_telemetry::error!(
|
|
"StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
|
|
&self,
|
|
&acc_args,
|
|
e
|
|
);
|
|
})
|
|
}
|
|
|
|
fn fix_inner_acc_args<'b>(
|
|
&self,
|
|
mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
|
|
) -> datafusion_common::Result<datafusion_expr::function::AccumulatorArgs<'b>> {
|
|
acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
|
|
Ok(acc_args)
|
|
}
|
|
}
|
|
|
|
impl AggregateUDFImpl for StateWrapper {
    /// Build the original function's accumulator (with fixed-up args) and wrap it
    /// in [`StateAccum`], which emits the intermediate state struct instead of the
    /// final value.
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        // fix and recover proper acc args for the original aggregate function.
        let state_type = acc_args.return_type().clone();
        let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?;

        Ok(Box::new(StateAccum::new(inner, state_type)?))
    }

    /// Supported iff the inner function supports a groups accumulator for the
    /// fixed-up args; any error while fixing args is treated as "not supported".
    fn groups_accumulator_supported(
        &self,
        acc_args: datafusion_expr::function::AccumulatorArgs,
    ) -> bool {
        self.fix_inner_acc_args(acc_args)
            .map(|args| self.inner.inner().groups_accumulator_supported(args))
            .unwrap_or(false)
    }

    /// Groups-accumulator counterpart of [`Self::accumulator`]: wraps the inner
    /// groups accumulator in [`StateGroupsAccum`].
    fn create_groups_accumulator(
        &self,
        acc_args: datafusion_expr::function::AccumulatorArgs,
    ) -> datafusion_common::Result<Box<dyn GroupsAccumulator>> {
        let state_type = acc_args.return_type().clone();
        let inner = self
            .inner
            .inner()
            .create_groups_accumulator(self.fix_inner_acc_args(acc_args)?)?;
        Ok(Box::new(StateGroupsAccum::new(inner, state_type)?))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// The `__<inner_name>_state` name computed at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Return state_fields as the output struct type.
    ///
    /// Input fields are synthesized from `arg_types` (names are irrelevant here);
    /// every state field is made nullable in the output struct.
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let input_fields = &arg_types
            .iter()
            .map(|x| Arc::new(Field::new("x", x.clone(), false)))
            .collect::<Vec<_>>();

        let state_fields_args = StateFieldsArgs {
            name: self.inner().name(),
            input_fields,
            return_field: self.inner.return_field(input_fields)?,
            // those args are also needed as they are vital to construct the state fields correctly.
            ordering_fields: &self.ordering,
            is_distinct: self.distinct,
        };
        let state_fields = self.inner.state_fields(state_fields_args)?;

        let state_fields = state_fields
            .into_iter()
            .map(|f| {
                let mut f = f.as_ref().clone();
                // since state can be null when no input rows, so make all fields nullable
                f.set_nullable(true);
                Arc::new(f)
            })
            .collect::<Vec<_>>();

        let struct_field = DataType::Struct(state_fields.into());
        Ok(struct_field)
    }

    /// The state function's output fields are the same as the original aggregate function's state fields.
    fn state_fields(
        &self,
        args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<FieldRef>> {
        // Forward everything except the return field, which must be the ORIGINAL
        // function's return field (ours is the state struct).
        let state_fields_args = StateFieldsArgs {
            name: args.name,
            input_fields: args.input_fields,
            return_field: self.inner.return_field(args.input_fields)?,
            ordering_fields: args.ordering_fields,
            is_distinct: args.is_distinct,
        };
        self.inner.state_fields(state_fields_args)
    }

    /// The state function's signature is the same as the original aggregate function's signature,
    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Coerce types also do nothing, as optimizer should be able to already make struct types
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }

    /// Statistics fast path: only count/min/max can answer from statistics, and
    /// their value equals the corresponding `_state` result, wrapped in a struct.
    fn value_from_stats(
        &self,
        statistics_args: &datafusion_expr::StatisticsArgs,
    ) -> Option<ScalarValue> {
        let inner = self.inner().inner().as_any();
        // only count/min/max need special handling here, for getting result from statistics
        // the result of count/min/max is also the result of count_state so can return directly
        let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
        if !can_use_stat {
            return None;
        }

        // fix return type by extract the first field's data type from the struct type
        let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
            if fields.is_empty() {
                return None;
            }
            fields[0].data_type().clone()
        } else {
            return None;
        };

        let fixed_args = datafusion_expr::StatisticsArgs {
            statistics: statistics_args.statistics,
            return_type: &state_type,
            is_distinct: statistics_args.is_distinct,
            exprs: statistics_args.exprs,
        };

        let ret = self.inner().value_from_stats(&fixed_args)?;

        // wrap the result into struct scalar value
        let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
            fields
        } else {
            return None;
        };

        let array = ret.to_array().ok()?;

        let struct_array = StructArray::new(fields.clone(), vec![array], None);
        let ret = ScalarValue::Struct(Arc::new(struct_array));
        Some(ret)
    }
}
|
|
|
|
/// The wrapper's input is the same as the original aggregate function's input,
/// and the output is the state function's output.
#[derive(Debug)]
pub struct StateAccum {
    /// The original aggregate function's accumulator.
    inner: Box<dyn Accumulator>,
    /// Fields of the struct that `evaluate` packs the inner state into.
    state_fields: Fields,
}
|
|
|
|
/// Groups-accumulator counterpart of [`StateAccum`]: delegates accumulation to the
/// inner groups accumulator and emits its state as a single struct column.
pub struct StateGroupsAccum {
    /// The original aggregate function's groups accumulator.
    inner: Box<dyn GroupsAccumulator>,
    /// Fields of the struct that `evaluate` packs the inner state columns into.
    state_fields: Fields,
}
|
|
|
|
impl StateGroupsAccum {
|
|
fn new(
|
|
inner: Box<dyn GroupsAccumulator>,
|
|
state_type: DataType,
|
|
) -> datafusion_common::Result<Self> {
|
|
let DataType::Struct(fields) = state_type else {
|
|
return Err(datafusion_common::DataFusionError::Internal(format!(
|
|
"Expected a struct type for state, got: {:?}",
|
|
state_type
|
|
)));
|
|
};
|
|
Ok(Self {
|
|
inner,
|
|
state_fields: fields,
|
|
})
|
|
}
|
|
|
|
fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
|
|
let array_type = arrays
|
|
.iter()
|
|
.map(|array| array.data_type().clone())
|
|
.collect::<Vec<_>>();
|
|
let expected_type = self
|
|
.state_fields
|
|
.iter()
|
|
.map(|field| field.data_type().clone())
|
|
.collect::<Vec<_>>();
|
|
if array_type != expected_type {
|
|
debug!(
|
|
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
|
|
self.state_fields.len(),
|
|
arrays.len(),
|
|
self.state_fields,
|
|
array_type,
|
|
);
|
|
let guess_schema = arrays
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(index, array)| {
|
|
Field::new(
|
|
format!("col_{index}[mismatch_state]").as_str(),
|
|
array.data_type().clone(),
|
|
true,
|
|
)
|
|
})
|
|
.collect::<Fields>();
|
|
let array = StructArray::try_new(guess_schema, arrays, None)?;
|
|
return Ok(Arc::new(array));
|
|
}
|
|
|
|
Ok(Arc::new(StructArray::try_new(
|
|
self.state_fields.clone(),
|
|
arrays,
|
|
None,
|
|
)?))
|
|
}
|
|
}
|
|
|
|
impl GroupsAccumulator for StateGroupsAccum {
    /// Delegate directly to the inner accumulator.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        self.inner
            .update_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Delegate directly to the inner accumulator.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        self.inner
            .merge_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Unlike the inner accumulator, `evaluate` emits the inner STATE (packed into
    /// a struct column) rather than the final value — that is the whole point of
    /// the state wrapper.
    fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
        let state = self.inner.state(emit_to)?;
        self.wrap_state_arrays(state)
    }

    /// Raw (un-wrapped) state columns, delegated to the inner accumulator.
    fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
        self.inner.state(emit_to)
    }

    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> datafusion_common::Result<Vec<ArrayRef>> {
        self.inner.convert_to_state(values, opt_filter)
    }

    fn supports_convert_to_state(&self) -> bool {
        self.inner.supports_convert_to_state()
    }

    fn size(&self) -> usize {
        self.inner.size()
    }
}
|
|
|
|
impl StateAccum {
|
|
pub fn new(
|
|
inner: Box<dyn Accumulator>,
|
|
state_type: DataType,
|
|
) -> datafusion_common::Result<Self> {
|
|
let DataType::Struct(fields) = state_type else {
|
|
return Err(datafusion_common::DataFusionError::Internal(format!(
|
|
"Expected a struct type for state, got: {:?}",
|
|
state_type
|
|
)));
|
|
};
|
|
Ok(Self {
|
|
inner,
|
|
state_fields: fields,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Accumulator for StateAccum {
    /// Emit the inner accumulator's STATE packed into a struct scalar (not the
    /// final value). When the state's array types don't match `state_fields`,
    /// falls back to a schema guessed from the actual arrays (logged at debug).
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
        let state = self.inner.state()?;

        let array = state
            .iter()
            .map(|s| s.to_array())
            .collect::<Result<Vec<_>, _>>()?;
        let array_type = array
            .iter()
            .map(|a| a.data_type().clone())
            .collect::<Vec<_>>();
        let expected_type: Vec<_> = self
            .state_fields
            .iter()
            .map(|f| f.data_type().clone())
            .collect();
        if array_type != expected_type {
            debug!(
                "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
                self.state_fields.len(),
                array.len(),
                self.state_fields,
                array_type,
            );
            // Best-effort fallback: synthesize nullable fields from the actual
            // array types so the state can still be packed and shipped.
            let guess_schema = array
                .iter()
                .enumerate()
                .map(|(index, array)| {
                    Field::new(
                        format!("col_{index}[mismatch_state]").as_str(),
                        array.data_type().clone(),
                        true,
                    )
                })
                .collect::<Fields>();
            let arr = StructArray::try_new(guess_schema, array, None)?;

            return Ok(ScalarValue::Struct(Arc::new(arr)));
        }

        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
        Ok(ScalarValue::Struct(Arc::new(struct_array)))
    }

    /// Delegate to the inner accumulator.
    fn merge_batch(
        &mut self,
        states: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.merge_batch(states)
    }

    /// Delegate to the inner accumulator (inputs are the original function's args).
    fn update_batch(
        &mut self,
        values: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.update_batch(values)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }

    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
        self.inner.state()
    }
}
|
|
|
|
/// TODO(discord9): mark this function as non-ser/de able
///
/// This wrapper shouldn't be registered as a udaf, as it contains extra data that is
/// not serializable and changes for different logical plans.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    /// The original aggregate function being wrapped.
    inner: AggregateUDF,
    /// Name of the merge function, i.e. `__<inner_name>_merge`.
    name: String,
    /// User-defined signature: the merge function's input is the state struct, not
    /// the original function's argument types (validated in `coerce_types`).
    merge_signature: Signature,
    /// The original physical expression of the aggregate function, can't store the original aggregate function directly, as PhysicalExpr didn't implement Any
    original_phy_expr: Arc<AggregateFunctionExpr>,
    /// Return field of the original aggregate call, computed once at construction
    /// from the original input fields.
    return_field: FieldRef,
}
|
|
impl MergeWrapper {
|
|
pub fn new(
|
|
inner: AggregateUDF,
|
|
original_phy_expr: Arc<AggregateFunctionExpr>,
|
|
original_input_fields: Vec<FieldRef>,
|
|
) -> datafusion_common::Result<Self> {
|
|
let name = aggr_merge_func_name(inner.name());
|
|
// the input type is actually struct type, which is the state fields of the original aggregate function.
|
|
let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
|
|
let return_field = inner.return_field(&original_input_fields)?.clone();
|
|
|
|
Ok(Self {
|
|
inner,
|
|
name,
|
|
merge_signature,
|
|
original_phy_expr,
|
|
return_field,
|
|
})
|
|
}
|
|
|
|
pub fn inner(&self) -> &AggregateUDF {
|
|
&self.inner
|
|
}
|
|
}
|
|
|
|
impl AggregateUDFImpl for MergeWrapper {
|
|
fn accumulator<'a, 'b>(
|
|
&'a self,
|
|
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
|
|
) -> datafusion_common::Result<Box<dyn Accumulator>> {
|
|
if acc_args.exprs.len() != 1
|
|
|| !matches!(
|
|
acc_args.exprs[0].data_type(acc_args.schema)?,
|
|
DataType::Struct(_)
|
|
)
|
|
{
|
|
return Err(datafusion_common::DataFusionError::Internal(format!(
|
|
"Expected one struct type as input, got: {:?}",
|
|
acc_args.schema
|
|
)));
|
|
}
|
|
let input_type = acc_args.exprs[0].data_type(acc_args.schema)?;
|
|
let DataType::Struct(fields) = input_type else {
|
|
return Err(datafusion_common::DataFusionError::Internal(format!(
|
|
"Expected a struct type for input, got: {:?}",
|
|
input_type
|
|
)));
|
|
};
|
|
|
|
let inner_accum = self.original_phy_expr.create_accumulator()?;
|
|
Ok(Box::new(MergeAccum::new(inner_accum, &fields)))
|
|
}
|
|
|
|
fn as_any(&self) -> &dyn std::any::Any {
|
|
self
|
|
}
|
|
fn name(&self) -> &str {
|
|
self.name.as_str()
|
|
}
|
|
|
|
fn is_nullable(&self) -> bool {
|
|
self.inner.is_nullable()
|
|
}
|
|
|
|
/// Notice here the `arg_types` is actually the `state_fields`'s data types,
|
|
/// so return fixed return type instead of using `arg_types` to determine the return type.
|
|
fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
|
|
// The return type is the same as the original aggregate function's return type.
|
|
Ok(self.return_field.data_type().clone())
|
|
}
|
|
|
|
/// Similar to return_type, we just return the fixed return field.
|
|
fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
|
|
Ok(self.return_field.clone())
|
|
}
|
|
|
|
fn signature(&self) -> &Signature {
|
|
&self.merge_signature
|
|
}
|
|
|
|
/// Coerce types also do nothing, as optimizer should be able to already make struct types
|
|
fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
|
|
// just check if the arg_types are only one and is struct array
|
|
if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
|
|
return Err(datafusion_common::DataFusionError::Internal(format!(
|
|
"Expected one struct type as input, got: {:?}",
|
|
arg_types
|
|
)));
|
|
}
|
|
Ok(arg_types.to_vec())
|
|
}
|
|
|
|
/// Just return the original aggregate function's state fields.
|
|
fn state_fields(
|
|
&self,
|
|
_args: datafusion_expr::function::StateFieldsArgs,
|
|
) -> datafusion_common::Result<Vec<FieldRef>> {
|
|
self.original_phy_expr.state_fields()
|
|
}
|
|
}
|
|
|
|
impl PartialEq for MergeWrapper {
    /// Equality considers only the wrapped UDF; the plan-specific fields
    /// (`original_phy_expr`, `return_field`, etc.) are intentionally ignored.
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}
|
|
|
|
// Marker impl; equality fully delegates to the wrapped `AggregateUDF`.
impl Eq for MergeWrapper {}
|
|
|
|
impl Hash for MergeWrapper {
    /// Hash only the wrapped UDF, keeping `Hash` consistent with `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.inner.hash(state);
    }
}
|
|
|
|
/// The merge accumulator, which modify `update_batch`'s behavior to accept one struct array which
/// include the state fields of original aggregate function, and merge said states into original accumulator
/// the output is the same as original aggregate function
#[derive(Debug)]
pub struct MergeAccum {
    /// The original aggregate function's accumulator.
    inner: Box<dyn Accumulator>,
    /// Expected layout of incoming state structs; checked (but only warned about
    /// on mismatch) in `update_batch`.
    state_fields: Fields,
}
|
|
|
|
impl MergeAccum {
    /// Create a merge accumulator wrapping the original function's accumulator.
    ///
    /// `state_fields` describes the struct layout produced by the corresponding
    /// state function.
    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
        Self {
            inner,
            state_fields: state_fields.clone(),
        }
    }
}
|
|
|
|
impl Accumulator for MergeAccum {
|
|
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
|
|
self.inner.evaluate()
|
|
}
|
|
|
|
fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
|
|
self.inner.merge_batch(states)
|
|
}
|
|
|
|
fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
|
|
let value = values.first().ok_or_else(|| {
|
|
datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
|
|
})?;
|
|
// The input values are states from other accumulators, so we merge them.
|
|
let struct_arr = value
|
|
.as_any()
|
|
.downcast_ref::<StructArray>()
|
|
.ok_or_else(|| {
|
|
datafusion_common::DataFusionError::Internal(format!(
|
|
"Expected StructArray, got: {:?}",
|
|
value.data_type()
|
|
))
|
|
})?;
|
|
let fields = struct_arr.fields();
|
|
if fields != &self.state_fields {
|
|
debug!(
|
|
"State fields mismatch, expected: {:?}, got: {:?}",
|
|
self.state_fields, fields
|
|
);
|
|
// state fields mismatch might be acceptable by datafusion, continue
|
|
}
|
|
|
|
// now fields should be the same, so we can merge the batch
|
|
// by pass the columns as order should be the same
|
|
let state_columns = struct_arr.columns();
|
|
self.inner.merge_batch(state_columns)
|
|
}
|
|
|
|
fn size(&self) -> usize {
|
|
self.inner.size()
|
|
}
|
|
|
|
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
|
|
self.inner.state()
|
|
}
|
|
}
|