perf: support group accumulators for state wrapper (#7826)

* perf: support group accumulators for state wrapper

* new tests and avoid clone

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2026-03-20 06:40:52 +08:00
committed by GitHub
parent 16fcbb2729
commit f034255fe6
2 changed files with 270 additions and 9 deletions

View File

@@ -25,7 +25,7 @@
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use arrow::array::StructArray;
use arrow::array::{ArrayRef, BooleanArray, StructArray};
use arrow_schema::{FieldRef, Fields};
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
@@ -38,8 +38,8 @@ use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
Signature,
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable,
GroupsAccumulator, LogicalPlan, Signature,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};
@@ -322,6 +322,14 @@ impl StateWrapper {
);
})
}
fn fix_inner_acc_args<'b>(
&self,
mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
) -> datafusion_common::Result<datafusion_expr::function::AccumulatorArgs<'b>> {
acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
Ok(acc_args)
}
}
impl AggregateUDFImpl for StateWrapper {
@@ -331,15 +339,32 @@ impl AggregateUDFImpl for StateWrapper {
) -> datafusion_common::Result<Box<dyn Accumulator>> {
// fix and recover proper acc args for the original aggregate function.
let state_type = acc_args.return_type().clone();
let inner = {
let mut new_acc_args = acc_args.clone();
new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
self.inner.accumulator(new_acc_args)?
};
let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?;
Ok(Box::new(StateAccum::new(inner, state_type)?))
}
fn groups_accumulator_supported(
&self,
acc_args: datafusion_expr::function::AccumulatorArgs,
) -> bool {
self.fix_inner_acc_args(acc_args)
.map(|args| self.inner.inner().groups_accumulator_supported(args))
.unwrap_or(false)
}
fn create_groups_accumulator(
&self,
acc_args: datafusion_expr::function::AccumulatorArgs,
) -> datafusion_common::Result<Box<dyn GroupsAccumulator>> {
let state_type = acc_args.return_type().clone();
let inner = self
.inner
.inner()
.create_groups_accumulator(self.fix_inner_acc_args(acc_args)?)?;
Ok(Box::new(StateGroupsAccum::new(inner, state_type)?))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
@@ -462,6 +487,118 @@ pub struct StateAccum {
state_fields: Fields,
}
pub struct StateGroupsAccum {
inner: Box<dyn GroupsAccumulator>,
state_fields: Fields,
}
impl StateGroupsAccum {
fn new(
inner: Box<dyn GroupsAccumulator>,
state_type: DataType,
) -> datafusion_common::Result<Self> {
let DataType::Struct(fields) = state_type else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected a struct type for state, got: {:?}",
state_type
)));
};
Ok(Self {
inner,
state_fields: fields,
})
}
fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
let array_type = arrays
.iter()
.map(|array| array.data_type().clone())
.collect::<Vec<_>>();
let expected_type = self
.state_fields
.iter()
.map(|field| field.data_type().clone())
.collect::<Vec<_>>();
if array_type != expected_type {
debug!(
"State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
self.state_fields.len(),
arrays.len(),
self.state_fields,
array_type,
);
let guess_schema = arrays
.iter()
.enumerate()
.map(|(index, array)| {
Field::new(
format!("col_{index}[mismatch_state]").as_str(),
array.data_type().clone(),
true,
)
})
.collect::<Fields>();
let array = StructArray::try_new(guess_schema, arrays, None)?;
return Ok(Arc::new(array));
}
Ok(Arc::new(StructArray::try_new(
self.state_fields.clone(),
arrays,
None,
)?))
}
}
impl GroupsAccumulator for StateGroupsAccum {
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> datafusion_common::Result<()> {
self.inner
.update_batch(values, group_indices, opt_filter, total_num_groups)
}
fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> datafusion_common::Result<()> {
self.inner
.merge_batch(values, group_indices, opt_filter, total_num_groups)
}
fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
let state = self.inner.state(emit_to)?;
self.wrap_state_arrays(state)
}
fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
self.inner.state(emit_to)
}
fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> datafusion_common::Result<Vec<ArrayRef>> {
self.inner.convert_to_state(values, opt_filter)
}
fn supports_convert_to_state(&self) -> bool {
self.inner.supports_convert_to_state()
}
fn size(&self) -> usize {
self.inner.size()
}
}
impl StateAccum {
pub fn new(
inner: Box<dyn Accumulator>,

View File

@@ -40,10 +40,13 @@ use datafusion_common::arrow::array::AsArray;
use datafusion_common::arrow::datatypes::{Float64Type, UInt64Type};
use datafusion_common::{Column, TableReference};
use datafusion_expr::expr::{AggregateFunction, NullTreatment};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
Aggregate, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, TableScan, lit,
Aggregate, AggregateUDFImpl, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr,
TableScan, lit,
};
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datatypes::arrow_array::StringArray;
use futures::{Stream, StreamExt as _};
@@ -256,6 +259,38 @@ fn dummy_table_scan_with_ts() -> LogicalPlan {
)
}
fn create_avg_state_groups_accumulator() -> Box<dyn GroupsAccumulator> {
let state_wrapper = StateWrapper::new((*avg_udaf()).clone()).unwrap();
let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
"number",
DataType::Float64,
true,
)]));
let expr = col("number", &schema).unwrap();
let expr_field = expr.return_field(&schema).unwrap();
let return_field = Arc::new(Field::new(
"__avg_state(number)",
state_wrapper.return_type(&[DataType::Float64]).unwrap(),
true,
));
let exprs = [expr];
let expr_fields = [expr_field];
let acc_args = AccumulatorArgs {
return_field,
schema: &schema,
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
name: "__avg_state(number)",
is_distinct: false,
exprs: &exprs,
expr_fields: &expr_fields,
};
assert!(state_wrapper.groups_accumulator_supported(acc_args.clone()));
state_wrapper.create_groups_accumulator(acc_args).unwrap()
}
#[tokio::test]
async fn test_sum_udaf() {
let ctx = SessionContext::new();
@@ -796,6 +831,95 @@ async fn test_last_value_order_by_udaf() {
assert_eq!(merge_eval_res, ScalarValue::Int64(Some(4)));
}
#[test]
fn test_avg_state_groups_accumulator_evaluate() {
let mut state_accum = create_avg_state_groups_accumulator();
let values = vec![Arc::new(Float64Array::from(vec![
Some(1.0),
Some(2.0),
None,
Some(3.0),
Some(4.0),
Some(5.0),
])) as ArrayRef];
let group_indices = vec![0, 1, 0, 0, 1, 2];
state_accum
.update_batch(&values, &group_indices, None, 3)
.unwrap();
let result = state_accum.evaluate(EmitTo::All).unwrap();
let result = result.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(
result
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap(),
&UInt64Array::from(vec![2, 2, 1])
);
assert_eq!(
result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap(),
&Float64Array::from(vec![4.0, 6.0, 5.0])
);
}
#[test]
fn test_avg_state_groups_accumulator_state_merge_evaluate() {
let mut source_accum = create_avg_state_groups_accumulator();
let source_values = vec![Arc::new(Float64Array::from(vec![
Some(1.0),
Some(2.0),
None,
Some(3.0),
Some(4.0),
Some(5.0),
])) as ArrayRef];
let source_group_indices = vec![0, 1, 0, 0, 1, 2];
source_accum
.update_batch(&source_values, &source_group_indices, None, 3)
.unwrap();
let source_state = source_accum.state(EmitTo::All).unwrap();
let mut merged_accum = create_avg_state_groups_accumulator();
let merged_values =
vec![Arc::new(Float64Array::from(vec![Some(10.0), Some(20.0), Some(30.0)])) as ArrayRef];
let merged_group_indices = vec![0, 1, 2];
merged_accum
.update_batch(&merged_values, &merged_group_indices, None, 3)
.unwrap();
merged_accum
.merge_batch(&source_state, &[1, 2, 0], None, 3)
.unwrap();
let result = merged_accum.evaluate(EmitTo::All).unwrap();
let result = result.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(
result
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap(),
&UInt64Array::from(vec![2, 3, 3])
);
assert_eq!(
result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap(),
&Float64Array::from(vec![15.0, 24.0, 36.0])
);
}
/// For testing whether the UDAF state fields are correctly implemented.
/// esp. for our own custom UDAF's state fields.
/// By compare eval results before and after split to state/merge functions.