diff --git a/src/common/function/src/aggrs/aggr_wrapper.rs b/src/common/function/src/aggrs/aggr_wrapper.rs index ab0e13e1c4..3780d39582 100644 --- a/src/common/function/src/aggrs/aggr_wrapper.rs +++ b/src/common/function/src/aggrs/aggr_wrapper.rs @@ -184,11 +184,11 @@ impl StateMergeHelper { ))); }; - let original_input_types = aggr_func + let original_input_fields = aggr_func .params .args .iter() - .map(|e| e.get_type(&aggr.input.schema())) + .map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field)) .collect::, _>>()?; // first create the state function from the original aggregate function. @@ -214,7 +214,7 @@ impl StateMergeHelper { let merge_func = MergeWrapper::new( (*aggr_func.func).clone(), original_phy_expr, - original_input_types, + original_input_fields, )?; let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name)); let expr = AggregateFunction { @@ -332,18 +332,9 @@ impl AggregateUDFImpl for StateWrapper { // fix and recover proper acc args for the original aggregate function. let state_type = acc_args.return_type().clone(); let inner = { - let acc_args = datafusion_expr::function::AccumulatorArgs { - return_field: self.deduce_aggr_return_type(&acc_args)?, - schema: acc_args.schema, - ignore_nulls: acc_args.ignore_nulls, - order_bys: acc_args.order_bys, - is_reversed: acc_args.is_reversed, - name: acc_args.name, - is_distinct: acc_args.is_distinct, - exprs: acc_args.exprs, - expr_fields: acc_args.expr_fields, - }; - self.inner.accumulator(acc_args)? + let mut new_acc_args = acc_args.clone(); + new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?; + self.inner.accumulator(new_acc_args)? }; Ok(Box::new(StateAccum::new(inner, state_type)?)) @@ -568,25 +559,25 @@ pub struct MergeWrapper { merge_signature: Signature, /// The original physical expression of the aggregate function, can't store the original aggregate function directly, as PhysicalExpr didn't implement Any original_phy_expr: Arc, - return_type: DataType, + return_field: FieldRef, } impl MergeWrapper { pub fn new( inner: AggregateUDF, original_phy_expr: Arc, - original_input_types: Vec, + original_input_fields: Vec, ) -> datafusion_common::Result { let name = aggr_merge_func_name(inner.name()); // the input type is actually struct type, which is the state fields of the original aggregate function. let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable); - let return_type = inner.return_type(&original_input_types)?; + let return_field = inner.return_field(&original_input_fields)?.clone(); Ok(Self { inner, name, merge_signature, original_phy_expr, - return_type, + return_field, }) } @@ -638,8 +629,14 @@ impl AggregateUDFImpl for MergeWrapper { /// so return fixed return type instead of using `arg_types` to determine the return type. fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result { // The return type is the same as the original aggregate function's return type. - Ok(self.return_type.clone()) + Ok(self.return_field.data_type().clone()) } + + /// Similar to return_type, we just return the fixed return field. + fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result { + Ok(self.return_field.clone()) + } + fn signature(&self) -> &Signature { &self.merge_signature } diff --git a/src/common/function/src/aggrs/aggr_wrapper/tests.rs b/src/common/function/src/aggrs/aggr_wrapper/tests.rs index f23277f818..32d9913539 100644 --- a/src/common/function/src/aggrs/aggr_wrapper/tests.rs +++ b/src/common/function/src/aggrs/aggr_wrapper/tests.rs @@ -296,6 +296,7 @@ async fn test_sum_udaf() { .unwrap(); assert_eq!(&res.lower_state, &expected_lower_plan); + let merge_input_fields = vec![Arc::new(Field::new("number", DataType::Int64, true))]; let expected_merge_plan = LogicalPlan::Aggregate( Aggregate::try_new( Arc::new(expected_lower_plan), @@ -319,7 +320,7 @@ async fn test_sum_udaf() { .build() .unwrap(), ), - vec![DataType::Int64], + merge_input_fields, ) .unwrap() .into(), @@ -459,6 +460,7 @@ async fn test_avg_udaf() { )]) ); + let merge_input_fields = vec![Arc::new(Field::new("number", DataType::Float64, true))]; let expected_merge_fn = MergeWrapper::new( avg.clone(), Arc::new( @@ -474,7 +476,7 @@ async fn test_avg_udaf() { .unwrap(), ), // coerced to float64 - vec![DataType::Float64], + merge_input_fields, ) .unwrap(); @@ -668,13 +670,15 @@ async fn test_last_value_order_by_udaf() { .unwrap(); let aggr_func_expr = &aggr_exec.aggr_expr()[0]; + let merge_input_fields = vec![Arc::new(Field::new( + "ts", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + false, + ))]; let expected_merge_fn = MergeWrapper::new( last_value.clone(), aggr_func_expr.clone(), - vec![DataType::Timestamp( - arrow_schema::TimeUnit::Millisecond, - None, - )], + merge_input_fields, ) .unwrap();