feat: state/merge wrapper for aggr func (#6377)

* refactor: move to query crate

Signed-off-by: discord9 <discord9@163.com>

* refactor: split to multiple columns

Signed-off-by: discord9 <discord9@163.com>

* feat: aggr merge accum wrapper

Signed-off-by: discord9 <discord9@163.com>

* rename shorter

Signed-off-by: discord9 <discord9@163.com>

* feat: add all in one helper

Signed-off-by: discord9 <discord9@163.com>

* tests: sum&avg

Signed-off-by: discord9 <discord9@163.com>

* chore: allow unused

Signed-off-by: discord9 <discord9@163.com>

* chore: typos

Signed-off-by: discord9 <discord9@163.com>

* refactor: per ds

Signed-off-by: discord9 <discord9@163.com>

* chore: fix tests

Signed-off-by: discord9 <discord9@163.com>

* refactor: move to common-function

Signed-off-by: discord9 <discord9@163.com>

* WIP massive refactor

Signed-off-by: discord9 <discord9@163.com>

* typo

Signed-off-by: discord9 <discord9@163.com>

* todo: stuff

Signed-off-by: discord9 <discord9@163.com>

* refactor: state2input type

Signed-off-by: discord9 <discord9@163.com>

* chore: rm unused

Signed-off-by: discord9 <discord9@163.com>

* refactor: per bot review

Signed-off-by: discord9 <discord9@163.com>

* chore: per bot

Signed-off-by: discord9 <discord9@163.com>

* refactor: rm duplicate infer type

Signed-off-by: discord9 <discord9@163.com>

* chore: better test

Signed-off-by: discord9 <discord9@163.com>

* fix: test sum refactor&fix wrong state types

Signed-off-by: discord9 <discord9@163.com>

* test: refactor avg udaf test

Signed-off-by: discord9 <discord9@163.com>

* refactor: split files

Signed-off-by: discord9 <discord9@163.com>

* refactor: docs&dedup

Signed-off-by: discord9 <discord9@163.com>

* refactor: allow merge to carry extra info

Signed-off-by: discord9 <discord9@163.com>

* chore: rm unused

Signed-off-by: discord9 <discord9@163.com>

* chore: clippy

Signed-off-by: discord9 <discord9@163.com>

* chore: docs&unused

Signed-off-by: discord9 <discord9@163.com>

* refactor: check fields equal

Signed-off-by: discord9 <discord9@163.com>

* test: test count_hash

Signed-off-by: discord9 <discord9@163.com>

* test: more custom udaf

Signed-off-by: discord9 <discord9@163.com>

* chore: clippy

Signed-off-by: discord9 <discord9@163.com>

* refactor: per review

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
discord9 authored on 2025-07-18 01:37:40 +08:00, committed by GitHub
parent 0cf25f7b05, commit 77b540ff68
5 changed files with 1354 additions and 0 deletions

Cargo.lock (generated)

@@ -2347,6 +2347,8 @@ dependencies = [
"api",
"approx 0.5.1",
"arc-swap",
"arrow 54.2.1",
"arrow-schema 54.3.1",
"async-trait",
"bincode",
"catalog",
@@ -2365,8 +2367,10 @@ dependencies = [
"datafusion-common",
"datafusion-expr",
"datafusion-functions-aggregate-common",
"datafusion-physical-expr",
"datatypes",
"derive_more",
"futures",
"geo",
"geo-types",
"geohash",
@@ -2379,6 +2383,7 @@ dependencies = [
"num-traits",
"once_cell",
"paste",
"pretty_assertions",
"s2",
"serde",
"serde_json",


@@ -16,6 +16,8 @@ geo = ["geohash", "h3o", "s2", "wkt", "geo-types", "dep:geo"]
ahash.workspace = true
api.workspace = true
arc-swap = "1.0"
arrow.workspace = true
arrow-schema.workspace = true
async-trait.workspace = true
bincode = "1.3"
catalog.workspace = true
@@ -34,6 +36,7 @@ datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-functions-aggregate-common.workspace = true
datafusion-physical-expr.workspace = true
datatypes.workspace = true
derive_more = { version = "1", default-features = false, features = ["display"] }
geo = { version = "0.29", optional = true }
@@ -62,5 +65,7 @@ wkt = { version = "0.11", optional = true }
[dev-dependencies]
approx = "0.5"
futures.workspace = true
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }
tokio.workspace = true


@@ -17,3 +17,5 @@ pub mod count_hash;
#[cfg(feature = "geo")]
pub mod geo;
pub mod vector;
pub mod aggr_wrapper;


@@ -0,0 +1,538 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Wrapper for making aggregate functions out of state/merge functions of original aggregate functions.
//!
//! i.e., for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`.
//!
//! `foo_state`'s input arguments are the same as `foo`'s, and its output is a state object.
//! Note that `foo_state` might have multiple output columns, so its output is a struct array
//! in which each output column is a struct field.
//! `foo_merge`'s input is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
//!
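//! For example (illustrative, mirroring the tests in this module), an aggregate plan
//! `Aggregate: sum(number)` is split into a lower state plan
//! `Aggregate: __sum_state(number)` and an upper merge plan
//! `Aggregate: __sum_merge(__sum_state(number)) AS "sum(number)"`.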
use std::sync::Arc;
use arrow::array::StructArray;
use arrow_schema::Fields;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
Signature,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};
/// Returns the name of the state function for the given aggregate function name.
/// The state function is used to compute the state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`.
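/// e.g. the state function for `sum` is named `__sum_state`.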
pub fn aggr_state_func_name(aggr_name: &str) -> String {
format!("__{}_state", aggr_name)
}
/// Returns the name of the merge function for the given aggregate function name.
/// The merge function is used to merge the states of the state functions.
/// The merge function's name is in the format `__<aggr_name>_merge`.
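/// e.g. the merge function for `sum` is named `__sum_merge`.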
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
format!("__{}_merge", aggr_name)
}
/// A helper to make aggregate functions out of the state and merge functions of an original aggregate function.
///
/// Notice a state function may have multiple output columns, so its return type is always a struct
/// array, and the merge function merges those state columns back into the original accumulator.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;
/// A struct to hold the two aggregate plans, one for the state function (lower) and one for the merge function (upper).
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
/// Upper merge plan, which is the aggregate plan that merges the states of the state function.
pub upper_merge: Arc<LogicalPlan>,
/// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
pub lower_state: Arc<LogicalPlan>,
}
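/// Returns the aggregate function expression inside `expr`, unwrapping any `Alias` layers.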
pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
let mut expr_ref = expr;
while let Expr::Alias(alias) = expr_ref {
expr_ref = &alias.expr;
}
if let Expr::AggregateFunction(aggr_func) = expr_ref {
Some(aggr_func)
} else {
None
}
}
impl StateMergeHelper {
/// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
let aggr = {
// certain aggregate functions need type coercion to work correctly, so analyze the plan first.
let aggr_plan = TypeCoercion::new().analyze(
LogicalPlan::Aggregate(aggr_plan),
&Default::default(),
)?;
if let LogicalPlan::Aggregate(aggr) = aggr_plan {
aggr
} else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
aggr_plan
)));
}
};
let mut lower_aggr_exprs = vec![];
let mut upper_aggr_exprs = vec![];
for aggr_expr in aggr.aggr_expr.iter() {
let Some(aggr_func) = get_aggr_func(aggr_expr) else {
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
"Unsupported aggregate expression for step aggr optimize: {:?}",
aggr_expr
)));
};
let original_input_types = aggr_func
.args
.iter()
.map(|e| e.get_type(&aggr.input.schema()))
.collect::<Result<Vec<_>, _>>()?;
// first create the state function from the original aggregate function.
let state_func = StateWrapper::new((*aggr_func.func).clone())?;
let expr = AggregateFunction {
func: Arc::new(state_func.into()),
args: aggr_func.args.clone(),
distinct: aggr_func.distinct,
filter: aggr_func.filter.clone(),
order_by: aggr_func.order_by.clone(),
null_treatment: aggr_func.null_treatment,
};
let expr = Expr::AggregateFunction(expr);
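// Record the state output column name (e.g. `__sum_state(number)`); the upper merge plan refers to this column.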
let lower_state_output_col_name = expr.schema_name().to_string();
lower_aggr_exprs.push(expr);
let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
aggr_expr,
aggr.input.schema(),
aggr.input.schema().as_arrow(),
&Default::default(),
)?;
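// Keep the original physical aggregate expression: the merge wrapper creates its accumulator
// from it and feeds the state columns to its `merge_batch`.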
let merge_func = MergeWrapper::new(
(*aggr_func.func).clone(),
original_phy_expr,
original_input_types,
)?;
let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
let expr = AggregateFunction {
func: Arc::new(merge_func.into()),
args: vec![arg],
distinct: aggr_func.distinct,
filter: aggr_func.filter.clone(),
order_by: aggr_func.order_by.clone(),
null_treatment: aggr_func.null_treatment,
};
// Alias to the original aggregate expr's schema name so the parent plan can refer to it correctly.
let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
upper_aggr_exprs.push(expr);
}
let mut lower = aggr.clone();
lower.aggr_expr = lower_aggr_exprs;
let lower_plan = LogicalPlan::Aggregate(lower);
// update aggregate's output schema
let lower_plan = Arc::new(lower_plan.recompute_schema()?);
let mut upper = aggr.clone();
let aggr_plan = LogicalPlan::Aggregate(aggr);
upper.aggr_expr = upper_aggr_exprs;
upper.input = lower_plan.clone();
// the upper plan's output schema should be the same as the original aggregate plan's output schema
let upper_check = upper.clone();
let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
if *upper_plan.schema() != *aggr_plan.schema() {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[ original]{}",
upper_plan.schema(), aggr_plan.schema()
)));
}
Ok(StepAggrPlan {
lower_state: lower_plan,
upper_merge: upper_plan,
})
}
}
/// Wrapper to make an aggregate function out of a state function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateWrapper {
inner: AggregateUDF,
name: String,
}
impl StateWrapper {
/// Wraps the given aggregate function so that it evaluates to its intermediate state.
pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
let name = aggr_state_func_name(inner.name());
Ok(Self { inner, name })
}
pub fn inner(&self) -> &AggregateUDF {
&self.inner
}
/// Deduce the return type of the original aggregate function
/// based on the accumulator arguments.
///
pub fn deduce_aggr_return_type(
&self,
acc_args: &datafusion_expr::function::AccumulatorArgs,
) -> datafusion_common::Result<DataType> {
let input_exprs = acc_args.exprs;
let input_schema = acc_args.schema;
let input_types = input_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>, _>>()?;
let return_type = self.inner.return_type(&input_types)?;
Ok(return_type)
}
}
impl AggregateUDFImpl for StateWrapper {
fn accumulator<'a, 'b>(
&'a self,
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
) -> datafusion_common::Result<Box<dyn Accumulator>> {
// Reconstruct the accumulator args that the original aggregate function expects.
let state_type = acc_args.return_type.clone();
let inner = {
let old_return_type = self.deduce_aggr_return_type(&acc_args)?;
let acc_args = datafusion_expr::function::AccumulatorArgs {
return_type: &old_return_type,
schema: acc_args.schema,
ignore_nulls: acc_args.ignore_nulls,
ordering_req: acc_args.ordering_req,
is_reversed: acc_args.is_reversed,
name: acc_args.name,
is_distinct: acc_args.is_distinct,
exprs: acc_args.exprs,
};
self.inner.accumulator(acc_args)?
};
Ok(Box::new(StateAccum::new(inner, state_type)?))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
self.name.as_str()
}
fn is_nullable(&self) -> bool {
self.inner.is_nullable()
}
/// Return state_fields as the output struct type.
///
fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
let old_return_type = self.inner.return_type(arg_types)?;
let state_fields_args = StateFieldsArgs {
name: self.inner().name(),
input_types: arg_types,
return_type: &old_return_type,
// TODO(discord9): how to get this? Probably ok to leave empty.
ordering_fields: &[],
is_distinct: false,
};
let state_fields = self.inner.state_fields(state_fields_args)?;
let struct_field = DataType::Struct(state_fields.into());
Ok(struct_field)
}
/// The state function's output fields are the same as the original aggregate function's state fields.
fn state_fields(
&self,
args: datafusion_expr::function::StateFieldsArgs,
) -> datafusion_common::Result<Vec<Field>> {
let old_return_type = self.inner.return_type(args.input_types)?;
let state_fields_args = StateFieldsArgs {
name: args.name,
input_types: args.input_types,
return_type: &old_return_type,
ordering_fields: args.ordering_fields,
is_distinct: args.is_distinct,
};
self.inner.state_fields(state_fields_args)
}
/// The state function's signature is the same as the original aggregate function's signature.
fn signature(&self) -> &Signature {
self.inner.signature()
}
/// Delegate type coercion to the original aggregate function.
fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
}
/// The wrapper's input is the same as the original aggregate function's input,
/// and the output is the state function's output.
#[derive(Debug)]
pub struct StateAccum {
inner: Box<dyn Accumulator>,
state_fields: Fields,
}
impl StateAccum {
pub fn new(
inner: Box<dyn Accumulator>,
state_type: DataType,
) -> datafusion_common::Result<Self> {
let DataType::Struct(fields) = state_type else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected a struct type for state, got: {:?}",
state_type
)));
};
Ok(Self {
inner,
state_fields: fields,
})
}
}
impl Accumulator for StateAccum {
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
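// Expose the inner accumulator's state as a single struct scalar: each state value becomes one struct field.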
let state = self.inner.state()?;
let array = state
.iter()
.map(|s| s.to_array())
.collect::<Result<Vec<_>, _>>()?;
let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
Ok(ScalarValue::Struct(Arc::new(struct_array)))
}
fn merge_batch(
&mut self,
states: &[datatypes::arrow::array::ArrayRef],
) -> datafusion_common::Result<()> {
self.inner.merge_batch(states)
}
fn update_batch(
&mut self,
values: &[datatypes::arrow::array::ArrayRef],
) -> datafusion_common::Result<()> {
self.inner.update_batch(values)
}
fn size(&self) -> usize {
self.inner.size()
}
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
self.inner.state()
}
}
/// TODO(discord9): mark this function as non-ser/de-able.
///
/// This wrapper shouldn't be registered as a UDAF, as it contains extra data that is not
/// serializable and that changes with different logical plans.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
inner: AggregateUDF,
name: String,
merge_signature: Signature,
/// The original physical expression of the aggregate function; we can't store the original
/// aggregate function directly, as `PhysicalExpr` doesn't implement `Any`.
original_phy_expr: Arc<AggregateFunctionExpr>,
original_input_types: Vec<DataType>,
}
impl MergeWrapper {
pub fn new(
inner: AggregateUDF,
original_phy_expr: Arc<AggregateFunctionExpr>,
original_input_types: Vec<DataType>,
) -> datafusion_common::Result<Self> {
let name = aggr_merge_func_name(inner.name());
// The input type is actually a struct type whose fields are the state fields of the original aggregate function.
let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
Ok(Self {
inner,
name,
merge_signature,
original_phy_expr,
original_input_types,
})
}
pub fn inner(&self) -> &AggregateUDF {
&self.inner
}
}
impl AggregateUDFImpl for MergeWrapper {
fn accumulator<'a, 'b>(
&'a self,
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
) -> datafusion_common::Result<Box<dyn Accumulator>> {
if acc_args.schema.fields().len() != 1
|| !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
{
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected one struct type as input, got: {:?}",
acc_args.schema
)));
}
let input_type = acc_args.schema.field(0).data_type();
let DataType::Struct(fields) = input_type else {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected a struct type for input, got: {:?}",
input_type
)));
};
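// Reuse the original aggregate's accumulator; `MergeAccum` redirects `update_batch` to its `merge_batch`.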
let inner_accum = self.original_phy_expr.create_accumulator()?;
Ok(Box::new(MergeAccum::new(inner_accum, fields)))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
self.name.as_str()
}
fn is_nullable(&self) -> bool {
self.inner.is_nullable()
}
/// Notice that `arg_types` here are actually the data types of the state fields,
/// so return a fixed return type instead of deriving it from `arg_types`.
fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
// The return type is the same as the original aggregate function's return type.
let ret_type = self.inner.return_type(&self.original_input_types)?;
Ok(ret_type)
}
fn signature(&self) -> &Signature {
&self.merge_signature
}
/// No coercion is performed; the optimizer should already produce the struct-typed state column.
fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
// just check that there is exactly one argument and that it is a struct type
if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected one struct type as input, got: {:?}",
arg_types
)));
}
Ok(arg_types.to_vec())
}
/// Just return the original aggregate function's state fields.
fn state_fields(
&self,
_args: datafusion_expr::function::StateFieldsArgs,
) -> datafusion_common::Result<Vec<Field>> {
self.original_phy_expr.state_fields()
}
}
/// The merge accumulator, which modifies `update_batch` to accept a single struct array containing
/// the state fields of the original aggregate function and merge those states into the original
/// accumulator. Its output is the same as the original aggregate function's.
#[derive(Debug)]
pub struct MergeAccum {
inner: Box<dyn Accumulator>,
state_fields: Fields,
}
impl MergeAccum {
pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
Self {
inner,
state_fields: state_fields.clone(),
}
}
}
impl Accumulator for MergeAccum {
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
self.inner.evaluate()
}
fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
self.inner.merge_batch(states)
}
fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
let value = values.first().ok_or_else(|| {
datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
})?;
// The input values are states from other accumulators, so we merge them.
let struct_arr = value
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| {
datafusion_common::DataFusionError::Internal(format!(
"Expected StructArray, got: {:?}",
value.data_type()
))
})?;
let fields = struct_arr.fields();
if fields != &self.state_fields {
return Err(datafusion_common::DataFusionError::Internal(format!(
"Expected state fields: {:?}, got: {:?}",
self.state_fields, fields
)));
}
// The fields match, so merge the batch by passing the columns through; the order is the same.
let state_columns = struct_arr.columns();
self.inner.merge_batch(state_columns)
}
fn size(&self) -> usize {
self.inner.size()
}
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
self.inner.state()
}
}
#[cfg(test)]
mod tests;


@@ -0,0 +1,804 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
use arrow::record_batch::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::catalog::{Session, TableProvider};
use datafusion::datasource::DefaultTableSource;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
use datafusion::prelude::SessionContext;
use datafusion_common::{Column, TableReference};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::sqlparser::ast::NullTreatment;
use datafusion_expr::{Aggregate, Expr, LogicalPlan, SortExpr, TableScan};
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datatypes::arrow_array::StringArray;
use futures::{Stream, StreamExt as _};
use pretty_assertions::assert_eq;
use super::*;
use crate::aggrs::approximate::hll::HllState;
use crate::aggrs::approximate::uddsketch::UddSketchState;
use crate::aggrs::count_hash::CountHash;
use crate::function::Function as _;
use crate::scalars::hll_count::HllCalcFunction;
use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
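/// A mock execution plan that replays the given record batches as its output stream.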
#[derive(Debug)]
pub struct MockInputExec {
input: Vec<RecordBatch>,
schema: SchemaRef,
properties: PlanProperties,
}
impl MockInputExec {
pub fn new(input: Vec<RecordBatch>, schema: SchemaRef) -> Self {
Self {
properties: PlanProperties::new(
EquivalenceProperties::new(schema.clone()),
Partitioning::UnknownPartitioning(1),
EmissionType::Incremental,
Boundedness::Bounded,
),
input,
schema,
}
}
}
impl DisplayAs for MockInputExec {
fn fmt_as(&self, _t: DisplayFormatType, _f: &mut std::fmt::Formatter) -> std::fmt::Result {
unimplemented!()
}
}
impl ExecutionPlan for MockInputExec {
fn name(&self) -> &str {
"MockInputExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
Ok(self)
}
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> datafusion_common::Result<SendableRecordBatchStream> {
let stream = MockStream {
stream: self.input.clone(),
schema: self.schema.clone(),
idx: 0,
};
Ok(Box::pin(stream))
}
}
struct MockStream {
stream: Vec<RecordBatch>,
schema: SchemaRef,
idx: usize,
}
impl Stream for MockStream {
type Item = datafusion_common::Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<datafusion_common::Result<RecordBatch>>> {
if self.idx < self.stream.len() {
let ret = self.stream[self.idx].clone();
self.idx += 1;
Poll::Ready(Some(Ok(ret)))
} else {
Poll::Ready(None)
}
}
}
impl RecordBatchStream for MockStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
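/// A table provider that serves an optional in-memory record batch with a fixed schema.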
#[derive(Debug)]
struct DummyTableProvider {
schema: Arc<arrow_schema::Schema>,
record_batch: Mutex<Option<RecordBatch>>,
}
impl DummyTableProvider {
#[allow(unused)]
pub fn new(schema: Arc<arrow_schema::Schema>, record_batch: Option<RecordBatch>) -> Self {
Self {
schema,
record_batch: Mutex::new(record_batch),
}
}
}
impl Default for DummyTableProvider {
fn default() -> Self {
Self {
schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
"number",
DataType::Int64,
true,
)])),
record_batch: Mutex::new(None),
}
}
}
#[async_trait::async_trait]
impl TableProvider for DummyTableProvider {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn schema(&self) -> Arc<arrow_schema::Schema> {
self.schema.clone()
}
fn table_type(&self) -> datafusion_expr::TableType {
datafusion_expr::TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
_projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
let input: Vec<RecordBatch> = self
.record_batch
.lock()
.unwrap()
.clone()
.map(|r| vec![r])
.unwrap_or_default();
Ok(Arc::new(MockInputExec::new(input, self.schema.clone())))
}
}
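/// Builds a logical `TableScan` over a dummy table with a single nullable Int64 column `number`.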
fn dummy_table_scan() -> LogicalPlan {
let table_provider = Arc::new(DummyTableProvider::default());
let table_source = DefaultTableSource::new(table_provider);
LogicalPlan::TableScan(
TableScan::try_new(
TableReference::bare("Number"),
Arc::new(table_source),
None,
vec![],
None,
)
.unwrap(),
)
}
#[tokio::test]
async fn test_sum_udaf() {
let ctx = SessionContext::new();
let sum = datafusion::functions_aggregate::sum::sum_udaf();
let sum = (*sum).clone();
let original_aggr = Aggregate::try_new(
Arc::new(dummy_table_scan()),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(sum.clone()),
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
None,
None,
))],
)
.unwrap();
let res = StateMergeHelper::split_aggr_node(original_aggr).unwrap();
let expected_lower_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(dummy_table_scan()),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(StateWrapper::new(sum.clone()).unwrap().into()),
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
None,
None,
))],
)
.unwrap(),
)
.recompute_schema()
.unwrap();
assert_eq!(res.lower_state.as_ref(), &expected_lower_plan);
let expected_merge_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(expected_lower_plan),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(
MergeWrapper::new(
sum.clone(),
Arc::new(
AggregateExprBuilder::new(
Arc::new(sum.clone()),
vec![Arc::new(
datafusion::physical_expr::expressions::Column::new(
"number", 0,
),
)],
)
.schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
.alias("sum(number)")
.build()
.unwrap(),
),
vec![DataType::Int64],
)
.unwrap()
.into(),
),
vec![Expr::Column(Column::new_unqualified("__sum_state(number)"))],
false,
None,
None,
None,
))
.alias("sum(number)")],
)
.unwrap(),
);
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&res.lower_state, &ctx.state())
.await
.unwrap();
let aggr_exec = phy_aggr_state_plan
.as_any()
.downcast_ref::<AggregateExec>()
.unwrap();
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
let mut state_accum = aggr_func_expr.create_accumulator().unwrap();
// evaluate the state function
let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
let values = vec![Arc::new(input) as arrow::array::ArrayRef];
state_accum.update_batch(&values).unwrap();
let state = state_accum.state().unwrap();
assert_eq!(state.len(), 1);
assert_eq!(state[0], ScalarValue::Int64(Some(6)));
let eval_res = state_accum.evaluate().unwrap();
assert_eq!(
eval_res,
ScalarValue::Struct(Arc::new(
StructArray::try_new(
vec![Field::new("sum[sum]", DataType::Int64, true)].into(),
vec![Arc::new(Int64Array::from(vec![Some(6)]))],
None,
)
.unwrap(),
))
);
let phy_aggr_merge_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&res.upper_merge, &ctx.state())
.await
.unwrap();
let aggr_exec = phy_aggr_merge_plan
.as_any()
.downcast_ref::<AggregateExec>()
.unwrap();
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
let mut merge_accum = aggr_func_expr.create_accumulator().unwrap();
let merge_input =
vec![Arc::new(Int64Array::from(vec![Some(6), Some(42), None])) as arrow::array::ArrayRef];
let merge_input_struct_arr = StructArray::try_new(
vec![Field::new("sum[sum]", DataType::Int64, true)].into(),
merge_input,
None,
)
.unwrap();
merge_accum
.update_batch(&[Arc::new(merge_input_struct_arr)])
.unwrap();
let merge_state = merge_accum.state().unwrap();
assert_eq!(merge_state.len(), 1);
assert_eq!(merge_state[0], ScalarValue::Int64(Some(48)));
let merge_eval_res = merge_accum.evaluate().unwrap();
assert_eq!(merge_eval_res, ScalarValue::Int64(Some(48)));
}
#[tokio::test]
async fn test_avg_udaf() {
let ctx = SessionContext::new();
let avg = datafusion::functions_aggregate::average::avg_udaf();
let avg = (*avg).clone();
let original_aggr = Aggregate::try_new(
Arc::new(dummy_table_scan()),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(avg.clone()),
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
None,
None,
))],
)
.unwrap();
let res = StateMergeHelper::split_aggr_node(original_aggr).unwrap();
let state_func: Arc<AggregateUDF> = Arc::new(StateWrapper::new(avg.clone()).unwrap().into());
let expected_aggr_state_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(dummy_table_scan()),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
state_func,
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
None,
None,
))],
)
.unwrap(),
);
// type coerced so the avg aggregate function works correctly
let coerced_aggr_state_plan = TypeCoercion::new()
.analyze(expected_aggr_state_plan.clone(), &Default::default())
.unwrap();
assert_eq!(res.lower_state.as_ref(), &coerced_aggr_state_plan);
assert_eq!(
res.lower_state.schema().as_arrow(),
&arrow_schema::Schema::new(vec![Field::new(
"__avg_state(number)",
DataType::Struct(
vec![
Field::new("avg[count]", DataType::UInt64, true),
Field::new("avg[sum]", DataType::Float64, true)
]
.into()
),
true,
)])
);
let expected_merge_fn = MergeWrapper::new(
avg.clone(),
Arc::new(
AggregateExprBuilder::new(
Arc::new(avg.clone()),
vec![Arc::new(
datafusion::physical_expr::expressions::Column::new("number", 0),
)],
)
.schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
.alias("avg(number)")
.build()
.unwrap(),
),
// coerced to float64
vec![DataType::Float64],
)
.unwrap();
let expected_merge_plan = LogicalPlan::Aggregate(
Aggregate::try_new(
Arc::new(coerced_aggr_state_plan.clone()),
vec![],
vec![Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(expected_merge_fn.into()),
vec![Expr::Column(Column::new_unqualified("__avg_state(number)"))],
false,
None,
None,
None,
))
.alias("avg(number)")],
)
.unwrap(),
);
assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&coerced_aggr_state_plan, &ctx.state())
.await
.unwrap();
let aggr_exec = phy_aggr_state_plan
.as_any()
.downcast_ref::<AggregateExec>()
.unwrap();
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
let mut state_accum = aggr_func_expr.create_accumulator().unwrap();
// evaluate the state function
let input = Float64Array::from(vec![Some(1.), Some(2.), None, Some(3.)]);
let values = vec![Arc::new(input) as arrow::array::ArrayRef];
state_accum.update_batch(&values).unwrap();
let state = state_accum.state().unwrap();
assert_eq!(state.len(), 2);
assert_eq!(state[0], ScalarValue::UInt64(Some(3)));
assert_eq!(state[1], ScalarValue::Float64(Some(6.)));
let eval_res = state_accum.evaluate().unwrap();
let expected = Arc::new(
StructArray::try_new(
vec![
Field::new("avg[count]", DataType::UInt64, true),
Field::new("avg[sum]", DataType::Float64, true),
]
.into(),
vec![
Arc::new(UInt64Array::from(vec![Some(3)])),
Arc::new(Float64Array::from(vec![Some(6.)])),
],
None,
)
.unwrap(),
);
assert_eq!(eval_res, ScalarValue::Struct(expected));
let phy_aggr_merge_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&res.upper_merge, &ctx.state())
.await
.unwrap();
let aggr_exec = phy_aggr_merge_plan
.as_any()
.downcast_ref::<AggregateExec>()
.unwrap();
let aggr_func_expr = &aggr_exec.aggr_expr()[0];
let mut merge_accum = aggr_func_expr.create_accumulator().unwrap();
let merge_input = vec![
Arc::new(UInt64Array::from(vec![Some(3), Some(42), None])) as arrow::array::ArrayRef,
Arc::new(Float64Array::from(vec![Some(48.), Some(84.), None])),
];
let merge_input_struct_arr = StructArray::try_new(
vec![
Field::new("avg[count]", DataType::UInt64, true),
Field::new("avg[sum]", DataType::Float64, true),
]
.into(),
merge_input,
None,
)
.unwrap();
merge_accum
.update_batch(&[Arc::new(merge_input_struct_arr)])
.unwrap();
let merge_state = merge_accum.state().unwrap();
assert_eq!(merge_state.len(), 2);
assert_eq!(merge_state[0], ScalarValue::UInt64(Some(45)));
assert_eq!(merge_state[1], ScalarValue::Float64(Some(132.)));
let merge_eval_res = merge_accum.evaluate().unwrap();
// the merge function returns the average, which is 132 / 45
assert_eq!(merge_eval_res, ScalarValue::Float64(Some(132. / 45_f64)));
}
/// Tests that UDAF state fields are correctly implemented, especially for our own custom UDAFs,
/// by comparing evaluation results before and after splitting into state/merge functions.
#[tokio::test]
async fn test_udaf_correct_eval_result() {
struct TestCase {
func: Arc<AggregateUDF>,
args: Vec<Expr>,
input_schema: SchemaRef,
input: Vec<ArrayRef>,
expected_output: Option<ScalarValue>,
expected_fn: Option<ExpectedFn>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<SortExpr>>,
null_treatment: Option<NullTreatment>,
}
type ExpectedFn = fn(ArrayRef) -> bool;
let test_cases = vec![
TestCase {
func: sum_udaf(),
input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
"number",
DataType::Int64,
true,
)])),
args: vec![Expr::Column(Column::new_unqualified("number"))],
input: vec![Arc::new(Int64Array::from(vec![
Some(1),
Some(2),
None,
Some(3),
]))],
expected_output: Some(ScalarValue::Int64(Some(6))),
expected_fn: None,
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
},
TestCase {
func: avg_udaf(),
input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
"number",
DataType::Int64,
true,
)])),
args: vec![Expr::Column(Column::new_unqualified("number"))],
input: vec![Arc::new(Int64Array::from(vec![
Some(1),
Some(2),
None,
Some(3),
]))],
expected_output: Some(ScalarValue::Float64(Some(2.0))),
expected_fn: None,
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
},
TestCase {
func: Arc::new(CountHash::udf_impl()),
input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
"number",
DataType::Int64,
true,
)])),
args: vec![Expr::Column(Column::new_unqualified("number"))],
input: vec![Arc::new(Int64Array::from(vec![
Some(1),
Some(2),
None,
Some(3),
Some(3),
Some(3),
]))],
expected_output: Some(ScalarValue::Int64(Some(4))),
expected_fn: None,
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
},
TestCase {
func: Arc::new(UddSketchState::state_udf_impl()),
input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
"number",
DataType::Float64,
true,
)])),
args: vec![
Expr::Literal(ScalarValue::Int64(Some(128))),
Expr::Literal(ScalarValue::Float64(Some(0.05))),
Expr::Column(Column::new_unqualified("number")),
],
input: vec![Arc::new(Float64Array::from(vec![
Some(1.),
Some(2.),
None,
Some(3.),
Some(3.),
Some(3.),
]))],
expected_output: None,
expected_fn: Some(|arr| {
let percent = ScalarValue::Float64(Some(0.5)).to_array().unwrap();
let percent = datatypes::vectors::Helper::try_into_vector(percent).unwrap();
let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
let udd_calc = UddSketchCalcFunction;
let res = udd_calc
.eval(&Default::default(), &[percent, state])
.unwrap();
let binding = res.to_arrow_array();
let res_arr = binding.as_any().downcast_ref::<Float64Array>().unwrap();
assert!(res_arr.len() == 1);
assert!((res_arr.value(0) - 2.856578984907706f64).abs() <= f64::EPSILON);
true
}),
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
},
TestCase {
func: Arc::new(HllState::state_udf_impl()),
input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
"word",
DataType::Utf8,
true,
)])),
args: vec![Expr::Column(Column::new_unqualified("word"))],
input: vec![Arc::new(StringArray::from(vec![
Some("foo"),
Some("bar"),
None,
Some("baz"),
Some("baz"),
]))],
expected_output: None,
expected_fn: Some(|arr| {
let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
let hll_calc = HllCalcFunction;
let res = hll_calc.eval(&Default::default(), &[state]).unwrap();
let binding = res.to_arrow_array();
let res_arr = binding.as_any().downcast_ref::<UInt64Array>().unwrap();
assert!(res_arr.len() == 1);
assert_eq!(res_arr.value(0), 3);
true
}),
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
},
// TODO(discord9): udd_merge/hll_merge/geo_path/quantile_aggr tests
];
let test_table_ref = TableReference::bare("TestTable");
for case in test_cases {
let ctx = SessionContext::new();
let table_provider = DummyTableProvider::new(
case.input_schema.clone(),
Some(RecordBatch::try_new(case.input_schema.clone(), case.input.clone()).unwrap()),
);
let table_source = DefaultTableSource::new(Arc::new(table_provider));
let logical_plan = LogicalPlan::TableScan(
TableScan::try_new(
test_table_ref.clone(),
Arc::new(table_source),
None,
vec![],
None,
)
.unwrap(),
);
let args = case.args;
let aggr_expr = Expr::AggregateFunction(AggregateFunction::new_udf(
case.func.clone(),
args,
case.distinct,
case.filter,
case.order_by,
case.null_treatment,
));
let aggr_plan = LogicalPlan::Aggregate(
Aggregate::try_new(Arc::new(logical_plan), vec![], vec![aggr_expr]).unwrap(),
);
// make sure the aggr_plan is type coerced
let aggr_plan = TypeCoercion::new()
.analyze(aggr_plan, &Default::default())
.unwrap();
// first eval the original aggregate function
let phy_full_aggr_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&aggr_plan, &ctx.state())
.await
.unwrap();
{
let unsplit_result = execute_phy_plan(&phy_full_aggr_plan).await.unwrap();
assert_eq!(unsplit_result.len(), 1);
let unsplit_batch = &unsplit_result[0];
assert_eq!(unsplit_batch.num_columns(), 1);
assert_eq!(unsplit_batch.num_rows(), 1);
let unsplit_col = unsplit_batch.column(0);
if let Some(expected_output) = &case.expected_output {
assert_eq!(unsplit_col.data_type(), &expected_output.data_type());
assert_eq!(unsplit_col.len(), 1);
assert_eq!(unsplit_col, &expected_output.to_array().unwrap());
}
if let Some(expected_fn) = &case.expected_fn {
assert!(expected_fn(unsplit_col.clone()));
}
}
let LogicalPlan::Aggregate(aggr_plan) = aggr_plan else {
panic!("Expected Aggregate plan");
};
let split_plan = StateMergeHelper::split_aggr_node(aggr_plan).unwrap();
let phy_upper_plan = DefaultPhysicalPlanner::default()
.create_physical_plan(&split_plan.upper_merge, &ctx.state())
.await
.unwrap();
// Since the upper plan uses the lower plan as input, executing the upper plan also executes
// the lower plan, which should give the same result as the original aggregate function.
{
let split_res = execute_phy_plan(&phy_upper_plan).await.unwrap();
assert_eq!(split_res.len(), 1);
let split_batch = &split_res[0];
assert_eq!(split_batch.num_columns(), 1);
assert_eq!(split_batch.num_rows(), 1);
let split_col = split_batch.column(0);
if let Some(expected_output) = &case.expected_output {
assert_eq!(split_col.data_type(), &expected_output.data_type());
assert_eq!(split_col.len(), 1);
assert_eq!(split_col, &expected_output.to_array().unwrap());
}
if let Some(expected_fn) = &case.expected_fn {
assert!(expected_fn(split_col.clone()));
}
}
}
}
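/// Executes the physical plan on a default task context and collects all output batches.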
async fn execute_phy_plan(
phy_plan: &Arc<dyn ExecutionPlan>,
) -> datafusion_common::Result<Vec<RecordBatch>> {
let task_ctx = Arc::new(TaskContext::default());
let mut stream = phy_plan.execute(0, task_ctx)?;
let mut batches = Vec::new();
while let Some(batch) = stream.next().await {
batches.push(batch?);
}
Ok(batches)
}