diff --git a/src/common/function/src/aggrs/aggr_wrapper.rs b/src/common/function/src/aggrs/aggr_wrapper.rs index 3780d39582..6242ab9454 100644 --- a/src/common/function/src/aggrs/aggr_wrapper.rs +++ b/src/common/function/src/aggrs/aggr_wrapper.rs @@ -25,7 +25,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::StructArray; +use arrow::array::{ArrayRef, BooleanArray, StructArray}; use arrow_schema::{FieldRef, Fields}; use common_telemetry::debug; use datafusion::functions_aggregate::all_default_aggregate_functions; @@ -38,8 +38,8 @@ use datafusion_common::{Column, ScalarValue}; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ - Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan, - Signature, + Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable, + GroupsAccumulator, LogicalPlan, Signature, }; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datatypes::arrow::datatypes::{DataType, Field}; @@ -322,6 +322,14 @@ impl StateWrapper { ); }) } + + fn fix_inner_acc_args<'b>( + &self, + mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>, + ) -> datafusion_common::Result> { + acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?; + Ok(acc_args) + } } impl AggregateUDFImpl for StateWrapper { @@ -331,15 +339,32 @@ impl AggregateUDFImpl for StateWrapper { ) -> datafusion_common::Result> { // fix and recover proper acc args for the original aggregate function. let state_type = acc_args.return_type().clone(); - let inner = { - let mut new_acc_args = acc_args.clone(); - new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?; - self.inner.accumulator(new_acc_args)? - }; + let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?; Ok(Box::new(StateAccum::new(inner, state_type)?)) } + fn groups_accumulator_supported( + &self, + acc_args: datafusion_expr::function::AccumulatorArgs, + ) -> bool { + self.fix_inner_acc_args(acc_args) + .map(|args| self.inner.inner().groups_accumulator_supported(args)) + .unwrap_or(false) + } + + fn create_groups_accumulator( + &self, + acc_args: datafusion_expr::function::AccumulatorArgs, + ) -> datafusion_common::Result> { + let state_type = acc_args.return_type().clone(); + let inner = self + .inner + .inner() + .create_groups_accumulator(self.fix_inner_acc_args(acc_args)?)?; + Ok(Box::new(StateGroupsAccum::new(inner, state_type)?)) + } + fn as_any(&self) -> &dyn std::any::Any { self } @@ -462,6 +487,118 @@ pub struct StateAccum { state_fields: Fields, } +pub struct StateGroupsAccum { + inner: Box, + state_fields: Fields, +} + +impl StateGroupsAccum { + fn new( + inner: Box, + state_type: DataType, + ) -> datafusion_common::Result { + let DataType::Struct(fields) = state_type else { + return Err(datafusion_common::DataFusionError::Internal(format!( + "Expected a struct type for state, got: {:?}", + state_type + ))); + }; + Ok(Self { + inner, + state_fields: fields, + }) + } + + fn wrap_state_arrays(&self, arrays: Vec) -> datafusion_common::Result { + let array_type = arrays + .iter() + .map(|array| array.data_type().clone()) + .collect::>(); + let expected_type = self + .state_fields + .iter() + .map(|field| field.data_type().clone()) + .collect::>(); + if array_type != expected_type { + debug!( + "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}", + self.state_fields.len(), + arrays.len(), + self.state_fields, + array_type, + ); + let guess_schema = arrays + .iter() + .enumerate() + .map(|(index, array)| { + Field::new( + format!("col_{index}[mismatch_state]").as_str(), + array.data_type().clone(), + true, + ) + }) + .collect::(); + let array = StructArray::try_new(guess_schema, arrays, None)?; + return Ok(Arc::new(array)); + } + + Ok(Arc::new(StructArray::try_new( + self.state_fields.clone(), + arrays, + None, + )?)) + } +} + +impl GroupsAccumulator for StateGroupsAccum { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> datafusion_common::Result<()> { + self.inner + .update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> datafusion_common::Result<()> { + self.inner + .merge_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result { + let state = self.inner.state(emit_to)?; + self.wrap_state_arrays(state) + } + + fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + self.inner.state(emit_to) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> datafusion_common::Result> { + self.inner.convert_to_state(values, opt_filter) + } + + fn supports_convert_to_state(&self) -> bool { + self.inner.supports_convert_to_state() + } + + fn size(&self) -> usize { + self.inner.size() + } +} + impl StateAccum { pub fn new( inner: Box, diff --git a/src/common/function/src/aggrs/aggr_wrapper/tests.rs b/src/common/function/src/aggrs/aggr_wrapper/tests.rs index 8821b9fd24..de3a77df6b 100644 --- a/src/common/function/src/aggrs/aggr_wrapper/tests.rs +++ b/src/common/function/src/aggrs/aggr_wrapper/tests.rs @@ -40,10 +40,13 @@ use datafusion_common::arrow::array::AsArray; use datafusion_common::arrow::datatypes::{Float64Type, UInt64Type}; use datafusion_common::{Column, TableReference}; use datafusion_expr::expr::{AggregateFunction, NullTreatment}; +use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Aggregate, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, TableScan, lit, + Aggregate, AggregateUDFImpl, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, + TableScan, lit, }; use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datatypes::arrow_array::StringArray; use futures::{Stream, StreamExt as _}; @@ -256,6 +259,38 @@ fn dummy_table_scan_with_ts() -> LogicalPlan { ) } +fn create_avg_state_groups_accumulator() -> Box { + let state_wrapper = StateWrapper::new((*avg_udaf()).clone()).unwrap(); + let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "number", + DataType::Float64, + true, + )])); + let expr = col("number", &schema).unwrap(); + let expr_field = expr.return_field(&schema).unwrap(); + let return_field = Arc::new(Field::new( + "__avg_state(number)", + state_wrapper.return_type(&[DataType::Float64]).unwrap(), + true, + )); + let exprs = [expr]; + let expr_fields = [expr_field]; + let acc_args = AccumulatorArgs { + return_field, + schema: &schema, + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "__avg_state(number)", + is_distinct: false, + exprs: &exprs, + expr_fields: &expr_fields, + }; + + assert!(state_wrapper.groups_accumulator_supported(acc_args.clone())); + state_wrapper.create_groups_accumulator(acc_args).unwrap() +} + #[tokio::test] async fn test_sum_udaf() { let ctx = SessionContext::new(); @@ -796,6 +831,95 @@ async fn test_last_value_order_by_udaf() { assert_eq!(merge_eval_res, ScalarValue::Int64(Some(4))); } +#[test] +fn test_avg_state_groups_accumulator_evaluate() { + let mut state_accum = create_avg_state_groups_accumulator(); + let values = vec![Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(2.0), + None, + Some(3.0), + Some(4.0), + Some(5.0), + ])) as ArrayRef]; + let group_indices = vec![0, 1, 0, 0, 1, 2]; + + state_accum + .update_batch(&values, &group_indices, None, 3) + .unwrap(); + + let result = state_accum.evaluate(EmitTo::All).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!( + result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(), + &UInt64Array::from(vec![2, 2, 1]) + ); + assert_eq!( + result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(), + &Float64Array::from(vec![4.0, 6.0, 5.0]) + ); +} + +#[test] +fn test_avg_state_groups_accumulator_state_merge_evaluate() { + let mut source_accum = create_avg_state_groups_accumulator(); + let source_values = vec![Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(2.0), + None, + Some(3.0), + Some(4.0), + Some(5.0), + ])) as ArrayRef]; + let source_group_indices = vec![0, 1, 0, 0, 1, 2]; + + source_accum + .update_batch(&source_values, &source_group_indices, None, 3) + .unwrap(); + let source_state = source_accum.state(EmitTo::All).unwrap(); + + let mut merged_accum = create_avg_state_groups_accumulator(); + let merged_values = + vec![Arc::new(Float64Array::from(vec![Some(10.0), Some(20.0), Some(30.0)])) as ArrayRef]; + let merged_group_indices = vec![0, 1, 2]; + + merged_accum + .update_batch(&merged_values, &merged_group_indices, None, 3) + .unwrap(); + merged_accum + .merge_batch(&source_state, &[1, 2, 0], None, 3) + .unwrap(); + + let result = merged_accum.evaluate(EmitTo::All).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!( + result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(), + &UInt64Array::from(vec![2, 3, 3]) + ); + assert_eq!( + result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(), + &Float64Array::from(vec![15.0, 24.0, 36.0]) + ); +} + /// For testing whether the UDAF state fields are correctly implemented. /// esp. for our own custom UDAF's state fields. /// By compare eval results before and after split to state/merge functions.