fix: count_state use stat to eval&predicate w/out region (#7116)

* fix: count_state use stat to eval

Signed-off-by: discord9 <discord9@163.com>

* cleanup

Signed-off-by: discord9 <discord9@163.com>

* fix: use predicate without region

Signed-off-by: discord9 <discord9@163.com>

* test: diverge standalone/dist impl

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2025-10-27 10:14:45 +08:00
committed by GitHub
parent e386a366d0
commit 68247fc9b1
13 changed files with 1221 additions and 177 deletions

View File

@@ -29,6 +29,8 @@ use arrow::array::StructArray;
use arrow_schema::{FieldRef, Fields};
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_aggregate::count::Count;
use datafusion::functions_aggregate::min_max::{Max, Min};
use datafusion::optimizer::AnalyzerRule;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
@@ -413,6 +415,51 @@ impl AggregateUDFImpl for StateWrapper {
fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
fn value_from_stats(
&self,
statistics_args: &datafusion_expr::StatisticsArgs,
) -> Option<ScalarValue> {
let inner = self.inner().inner().as_any();
// only count/min/max need special handling here, for getting result from statistics
// the result of count/min/max is also the result of count_state so can return directly
let can_use_stat = inner.is::<Count>() || inner.is::<Max>() || inner.is::<Min>();
if !can_use_stat {
return None;
}
// fix return type by extract the first field's data type from the struct type
let state_type = if let DataType::Struct(fields) = &statistics_args.return_type {
if fields.is_empty() {
return None;
}
fields[0].data_type().clone()
} else {
return None;
};
let fixed_args = datafusion_expr::StatisticsArgs {
statistics: statistics_args.statistics,
return_type: &state_type,
is_distinct: statistics_args.is_distinct,
exprs: statistics_args.exprs,
};
let ret = self.inner().value_from_stats(&fixed_args)?;
// wrap the result into struct scalar value
let fields = if let DataType::Struct(fields) = &statistics_args.return_type {
fields
} else {
return None;
};
let array = ret.to_array().ok()?;
let struct_array = StructArray::new(fields.clone(), vec![array], None);
let ret = ScalarValue::Struct(Arc::new(struct_array));
Some(ret)
}
}
/// The wrapper's input is the same as the original aggregate function's input,