mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-19 06:20:38 +00:00
chore: aggr wrapper use return_field (#7582)
Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
@@ -184,11 +184,11 @@ impl StateMergeHelper {
|
||||
)));
|
||||
};
|
||||
|
||||
let original_input_types = aggr_func
|
||||
let original_input_fields = aggr_func
|
||||
.params
|
||||
.args
|
||||
.iter()
|
||||
.map(|e| e.get_type(&aggr.input.schema()))
|
||||
.map(|e| e.to_field(&aggr.input.schema()).map(|(_, field)| field))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// first create the state function from the original aggregate function.
|
||||
@@ -214,7 +214,7 @@ impl StateMergeHelper {
|
||||
let merge_func = MergeWrapper::new(
|
||||
(*aggr_func.func).clone(),
|
||||
original_phy_expr,
|
||||
original_input_types,
|
||||
original_input_fields,
|
||||
)?;
|
||||
let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
|
||||
let expr = AggregateFunction {
|
||||
@@ -332,18 +332,9 @@ impl AggregateUDFImpl for StateWrapper {
|
||||
// fix and recover proper acc args for the original aggregate function.
|
||||
let state_type = acc_args.return_type().clone();
|
||||
let inner = {
|
||||
let acc_args = datafusion_expr::function::AccumulatorArgs {
|
||||
return_field: self.deduce_aggr_return_type(&acc_args)?,
|
||||
schema: acc_args.schema,
|
||||
ignore_nulls: acc_args.ignore_nulls,
|
||||
order_bys: acc_args.order_bys,
|
||||
is_reversed: acc_args.is_reversed,
|
||||
name: acc_args.name,
|
||||
is_distinct: acc_args.is_distinct,
|
||||
exprs: acc_args.exprs,
|
||||
expr_fields: acc_args.expr_fields,
|
||||
};
|
||||
self.inner.accumulator(acc_args)?
|
||||
let mut new_acc_args = acc_args.clone();
|
||||
new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
|
||||
self.inner.accumulator(new_acc_args)?
|
||||
};
|
||||
|
||||
Ok(Box::new(StateAccum::new(inner, state_type)?))
|
||||
@@ -568,25 +559,25 @@ pub struct MergeWrapper {
|
||||
merge_signature: Signature,
|
||||
/// The original physical expression of the aggregate function, can't store the original aggregate function directly, as PhysicalExpr didn't implement Any
|
||||
original_phy_expr: Arc<AggregateFunctionExpr>,
|
||||
return_type: DataType,
|
||||
return_field: FieldRef,
|
||||
}
|
||||
impl MergeWrapper {
|
||||
pub fn new(
|
||||
inner: AggregateUDF,
|
||||
original_phy_expr: Arc<AggregateFunctionExpr>,
|
||||
original_input_types: Vec<DataType>,
|
||||
original_input_fields: Vec<FieldRef>,
|
||||
) -> datafusion_common::Result<Self> {
|
||||
let name = aggr_merge_func_name(inner.name());
|
||||
// the input type is actually struct type, which is the state fields of the original aggregate function.
|
||||
let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
|
||||
let return_type = inner.return_type(&original_input_types)?;
|
||||
let return_field = inner.return_field(&original_input_fields)?.clone();
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
name,
|
||||
merge_signature,
|
||||
original_phy_expr,
|
||||
return_type,
|
||||
return_field,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -638,8 +629,14 @@ impl AggregateUDFImpl for MergeWrapper {
|
||||
/// so return fixed return type instead of using `arg_types` to determine the return type.
|
||||
fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
// The return type is the same as the original aggregate function's return type.
|
||||
Ok(self.return_type.clone())
|
||||
Ok(self.return_field.data_type().clone())
|
||||
}
|
||||
|
||||
/// Similar to return_type, we just return the fixed return field.
|
||||
fn return_field(&self, _arg_fields: &[FieldRef]) -> datafusion_common::Result<FieldRef> {
|
||||
Ok(self.return_field.clone())
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.merge_signature
|
||||
}
|
||||
|
||||
@@ -296,6 +296,7 @@ async fn test_sum_udaf() {
|
||||
.unwrap();
|
||||
assert_eq!(&res.lower_state, &expected_lower_plan);
|
||||
|
||||
let merge_input_fields = vec![Arc::new(Field::new("number", DataType::Int64, true))];
|
||||
let expected_merge_plan = LogicalPlan::Aggregate(
|
||||
Aggregate::try_new(
|
||||
Arc::new(expected_lower_plan),
|
||||
@@ -319,7 +320,7 @@ async fn test_sum_udaf() {
|
||||
.build()
|
||||
.unwrap(),
|
||||
),
|
||||
vec![DataType::Int64],
|
||||
merge_input_fields,
|
||||
)
|
||||
.unwrap()
|
||||
.into(),
|
||||
@@ -459,6 +460,7 @@ async fn test_avg_udaf() {
|
||||
)])
|
||||
);
|
||||
|
||||
let merge_input_fields = vec![Arc::new(Field::new("number", DataType::Float64, true))];
|
||||
let expected_merge_fn = MergeWrapper::new(
|
||||
avg.clone(),
|
||||
Arc::new(
|
||||
@@ -474,7 +476,7 @@ async fn test_avg_udaf() {
|
||||
.unwrap(),
|
||||
),
|
||||
// coerced to float64
|
||||
vec![DataType::Float64],
|
||||
merge_input_fields,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -668,13 +670,15 @@ async fn test_last_value_order_by_udaf() {
|
||||
.unwrap();
|
||||
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
|
||||
|
||||
let merge_input_fields = vec![Arc::new(Field::new(
|
||||
"ts",
|
||||
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
|
||||
false,
|
||||
))];
|
||||
let expected_merge_fn = MergeWrapper::new(
|
||||
last_value.clone(),
|
||||
aggr_func_expr.clone(),
|
||||
vec![DataType::Timestamp(
|
||||
arrow_schema::TimeUnit::Millisecond,
|
||||
None,
|
||||
)],
|
||||
merge_input_fields,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user