fix: group by expr not as column in step aggr (#7008)

* fix: group by expr not as column

Signed-off-by: discord9 <discord9@163.com>

* test: dist analyzer date_bin

Signed-off-by: discord9 <discord9@163.com>

* ???fix wip

Signed-off-by: discord9 <discord9@163.com>

* fix: deduce using correct input fields

Signed-off-by: discord9 <discord9@163.com>

* refactor: clearer wrapper

Signed-off-by: discord9 <discord9@163.com>

* chore: update sqlness

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

* chore: rm todo

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2025-09-24 14:57:01 +08:00
committed by GitHub
parent 0c038f755f
commit 238ed003df
12 changed files with 830 additions and 154 deletions

View File

@@ -137,6 +137,15 @@ impl StateMergeHelper {
let mut lower_aggr_exprs = vec![];
let mut upper_aggr_exprs = vec![];
// Group exprs of the upper plan should refer to the lower plan's output group exprs as columns,
// so that the group exprs are not computed a second time.
let upper_group_exprs = aggr
.group_expr
.iter()
.map(|c| c.qualified_name())
.map(|(r, c)| Expr::Column(Column::new(r, c)))
.collect();
for aggr_expr in aggr.aggr_expr.iter() {
let Some(aggr_func) = get_aggr_func(aggr_expr) else {
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
@@ -198,10 +207,13 @@ impl StateMergeHelper {
// update aggregate's output schema
let lower_plan = lower_plan.recompute_schema()?;
let mut upper = aggr.clone();
let upper = Aggregate::try_new(
Arc::new(lower_plan.clone()),
upper_group_exprs,
upper_aggr_exprs.clone(),
)?;
let aggr_plan = LogicalPlan::Aggregate(aggr);
upper.aggr_expr = upper_aggr_exprs;
upper.input = Arc::new(lower_plan.clone());
// The upper plan's output schema should be the same as the original aggregate plan's output schema.
let upper_check = upper;
let upper_plan = LogicalPlan::Aggregate(upper_check).recompute_schema()?;
@@ -245,7 +257,19 @@ impl StateWrapper {
&self,
acc_args: &datafusion_expr::function::AccumulatorArgs,
) -> datafusion_common::Result<FieldRef> {
self.inner.return_field(acc_args.schema.fields())
let input_fields = acc_args
.exprs
.iter()
.map(|e| e.return_field(acc_args.schema))
.collect::<Result<Vec<_>, _>>()?;
self.inner.return_field(&input_fields).inspect_err(|e| {
common_telemetry::error!(
"StateWrapper: {:#?}\nacc_args:{:?}\nerror:{:?}",
&self,
&acc_args,
e
);
})
}
}
@@ -402,7 +426,7 @@ pub struct MergeWrapper {
merge_signature: Signature,
/// The original physical expression of the aggregate function. The original aggregate function can't be stored directly, as `PhysicalExpr` doesn't implement `Any`.
original_phy_expr: Arc<AggregateFunctionExpr>,
original_input_types: Vec<DataType>,
return_type: DataType,
}
impl MergeWrapper {
pub fn new(
@@ -413,13 +437,14 @@ impl MergeWrapper {
let name = aggr_merge_func_name(inner.name());
// The input type is actually a struct type whose fields are the state fields of the original aggregate function.
let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
let return_type = inner.return_type(&original_input_types)?;
Ok(Self {
inner,
name,
merge_signature,
original_phy_expr,
original_input_types,
return_type,
})
}
@@ -471,8 +496,7 @@ impl AggregateUDFImpl for MergeWrapper {
/// so return fixed return type instead of using `arg_types` to determine the return type.
fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
// The return type is the same as the original aggregate function's return type.
let ret_type = self.inner.return_type(&self.original_input_types)?;
Ok(ret_type)
Ok(self.return_type.clone())
}
fn signature(&self) -> &Signature {
&self.merge_signature

View File

@@ -23,6 +23,7 @@ use datafusion::catalog::{Session, TableProvider};
use datafusion::datasource::DefaultTableSource;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::optimizer::AnalyzerRule;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
@@ -555,6 +556,7 @@ async fn test_udaf_correct_eval_result() {
input_schema: SchemaRef,
input: Vec<ArrayRef>,
expected_output: Option<ScalarValue>,
// extra check function on the final array result
expected_fn: Option<ExpectedFn>,
distinct: bool,
filter: Option<Box<Expr>>,
@@ -585,6 +587,27 @@ async fn test_udaf_correct_eval_result() {
order_by: vec![],
null_treatment: None,
},
TestCase {
func: count_udaf(),
input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
"str_val",
DataType::Utf8,
true,
)])),
args: vec![Expr::Column(Column::new_unqualified("str_val"))],
input: vec![Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
None,
Some("what"),
]))],
expected_output: Some(ScalarValue::Int64(Some(3))),
expected_fn: None,
distinct: false,
filter: None,
order_by: vec![],
null_treatment: None,
},
TestCase {
func: avg_udaf(),
input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(