chore: aggr wrapper use return_field (#7582)

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-01-16 15:12:21 +08:00
committed by GitHub
parent 4e35028e28
commit 21433f09e3
2 changed files with 27 additions and 26 deletions

View File

@@ -184,11 +184,11 @@ impl StateMergeHelper {
)));
};
let original_input_types = aggr_func
let original_input_fields = aggr_func
.params
.args
.iter()
.map(|e| e.get_type(&aggr.input.schema()))
.map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
.collect::<Result<Vec<_>, _>>()?;
// first create the state function from the original aggregate function.
@@ -214,7 +214,7 @@ impl StateMergeHelper {
let merge_func = MergeWrapper::new(
(*aggr_func.func).clone(),
original_phy_expr,
original_input_types,
original_input_fields,
)?;
let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
let expr = AggregateFunction {
@@ -332,18 +332,9 @@ impl AggregateUDFImpl for StateWrapper {
// fix and recover proper acc args for the original aggregate function.
let state_type = acc_args.return_type().clone();
let inner = {
let acc_args = datafusion_expr::function::AccumulatorArgs {
return_field: self.deduce_aggr_return_type(&acc_args)?,
schema: acc_args.schema,
ignore_nulls: acc_args.ignore_nulls,
order_bys: acc_args.order_bys,
is_reversed: acc_args.is_reversed,
name: acc_args.name,
is_distinct: acc_args.is_distinct,
exprs: acc_args.exprs,
expr_fields: acc_args.expr_fields,
};
self.inner.accumulator(acc_args)?
let mut new_acc_args = acc_args.clone();
new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
self.inner.accumulator(new_acc_args)?
};
Ok(Box::new(StateAccum::new(inner, state_type)?))
@@ -568,25 +559,25 @@ pub struct MergeWrapper {
merge_signature: Signature,
/// The original physical expression of the aggregate function. The original aggregate function can't be stored directly, as `PhysicalExpr` doesn't implement `Any`.
original_phy_expr: Arc<AggregateFunctionExpr>,
return_type: DataType,
return_field: FieldRef,
}
impl MergeWrapper {
pub fn new(
inner: AggregateUDF,
original_phy_expr: Arc<AggregateFunctionExpr>,
original_input_types: Vec<DataType>,
original_input_fields: Vec<FieldRef>,
) -> datafusion_common::Result<Self> {
let name = aggr_merge_func_name(inner.name());
// The input type is actually a struct type whose fields are the state fields of the original aggregate function.
let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
let return_type = inner.return_type(&original_input_types)?;
let return_field = inner.return_field(&original_input_fields)?.clone();
Ok(Self {
inner,
name,
merge_signature,
original_phy_expr,
return_type,
return_field,
})
}
@@ -638,8 +629,14 @@ impl AggregateUDFImpl for MergeWrapper {
/// so return fixed return type instead of using `arg_types` to determine the return type.
fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
// The return type is the same as the original aggregate function's return type.
Ok(self.return_type.clone())
Ok(self.return_field.data_type().clone())
}
/// Returns the fixed return field held by this wrapper.
///
/// Similar to `return_type`, the argument fields are ignored (`_arg_fields` is unused):
/// the result is always a clone of `self.return_field`, which `MergeWrapper::new`
/// obtains from the inner aggregate's `return_field` on the original input fields.
fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
Ok(self.return_field.clone())
}
/// Returns a reference to the signature stored in `merge_signature`.
///
/// `MergeWrapper::new` builds that signature with `Signature::user_defined` and
/// `Volatility::Immutable`.
fn signature(&self) -> &Signature {
&self.merge_signature
}

View File

@@ -296,6 +296,7 @@ async fn test_sum_udaf() {
.unwrap();
assert_eq!(&res.lower_state, &expected_lower_plan);
let merge_input_fields = vec![Arc::new(Field::new("number", DataType::Int64, true))];
let expected_merge_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(expected_lower_plan),
@@ -319,7 +320,7 @@ async fn test_sum_udaf() {
.build()
.unwrap(),
),
vec![DataType::Int64],
merge_input_fields,
)
.unwrap()
.into(),
@@ -459,6 +460,7 @@ async fn test_avg_udaf() {
)])
);
let merge_input_fields = vec![Arc::new(Field::new("number", DataType::Float64, true))];
let expected_merge_fn = MergeWrapper::new(
avg.clone(),
Arc::new(
@@ -474,7 +476,7 @@ async fn test_avg_udaf() {
.unwrap(),
),
// coerced to float64
vec![DataType::Float64],
merge_input_fields,
)
.unwrap();
@@ -668,13 +670,15 @@ async fn test_last_value_order_by_udaf() {
.unwrap();
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
let merge_input_fields = vec![Arc::new(Field::new(
"ts",
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
false,
))];
let expected_merge_fn = MergeWrapper::new(
last_value.clone(),
aggr_func_expr.clone(),
vec![DataType::Timestamp(
arrow_schema::TimeUnit::Millisecond,
None,
)],
merge_input_fields,
)
.unwrap();