From 77b540ff68b8eb656014a7f5b8b2c8465a00404f Mon Sep 17 00:00:00 2001
From: discord9 <55937128+discord9@users.noreply.github.com>
Date: Fri, 18 Jul 2025 01:37:40 +0800
Subject: [PATCH] feat: state/merge wrapper for aggr func (#6377)

* refactor: move to query crate
Signed-off-by: discord9
* refactor: split to multiple columns
Signed-off-by: discord9
* feat: aggr merge accum wrapper
Signed-off-by: discord9
* rename shorter
Signed-off-by: discord9
* feat: add all in one helper
Signed-off-by: discord9
* tests: sum&avg
Signed-off-by: discord9
* chore: allow unused
Signed-off-by: discord9
* chore: typos
Signed-off-by: discord9
* refactor: per ds
Signed-off-by: discord9
* chore: fix tests
Signed-off-by: discord9
* refactor: move to common-function
Signed-off-by: discord9
* WIP massive refactor
Signed-off-by: discord9
* typo
Signed-off-by: discord9
* todo: stuff
Signed-off-by: discord9
* refactor: state2input type
Signed-off-by: discord9
* chore: rm unused
Signed-off-by: discord9
* refactor: per bot review
Signed-off-by: discord9
* chore: per bot
Signed-off-by: discord9
* refactor: rm duplicate infer type
Signed-off-by: discord9
* chore: better test
Signed-off-by: discord9
* fix: test sum refactor&fix wrong state types
Signed-off-by: discord9
* test: refactor avg udaf test
Signed-off-by: discord9
* refactor: split files
Signed-off-by: discord9
* refactor: docs&dedup
Signed-off-by: discord9
* refactor: allow merge to carry extra info
Signed-off-by: discord9
* chore: rm unused
Signed-off-by: discord9
* chore: clippy
Signed-off-by: discord9
* chore: docs&unused
Signed-off-by: discord9
* refactor: check fields equal
Signed-off-by: discord9
* test: test count_hash
Signed-off-by: discord9
* test: more custom udaf
Signed-off-by: discord9
* chore: clippy
Signed-off-by: discord9
* refactor: per review
Signed-off-by: discord9

---------

Signed-off-by: discord9
---
 Cargo.lock                                    |   5 +
 src/common/function/Cargo.toml                |   5 +
 src/common/function/src/aggrs.rs              |   2 +
 src/common/function/src/aggrs/aggr_wrapper.rs | 538 ++++++++++++
 .../function/src/aggrs/aggr_wrapper/tests.rs  | 804 ++++++++++++++++++
 5 files changed, 1354 insertions(+)
 create mode 100644 src/common/function/src/aggrs/aggr_wrapper.rs
 create mode 100644 src/common/function/src/aggrs/aggr_wrapper/tests.rs

diff --git a/Cargo.lock b/Cargo.lock
index 6fbfa8e9fd..06325ccc16 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2347,6 +2347,8 @@ dependencies = [
  "api",
  "approx 0.5.1",
  "arc-swap",
+ "arrow 54.2.1",
+ "arrow-schema 54.3.1",
  "async-trait",
  "bincode",
  "catalog",
@@ -2365,8 +2367,10 @@ dependencies = [
  "datafusion-common",
  "datafusion-expr",
  "datafusion-functions-aggregate-common",
+ "datafusion-physical-expr",
  "datatypes",
  "derive_more",
+ "futures",
  "geo",
  "geo-types",
  "geohash",
@@ -2379,6 +2383,7 @@ dependencies = [
  "num-traits",
  "once_cell",
  "paste",
+ "pretty_assertions",
  "s2",
  "serde",
  "serde_json",
diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml
index 2a67488d32..710a02ad8f 100644
--- a/src/common/function/Cargo.toml
+++ b/src/common/function/Cargo.toml
@@ -16,6 +16,8 @@ geo = ["geohash", "h3o", "s2", "wkt", "geo-types", "dep:geo"]
 ahash.workspace = true
 api.workspace = true
 arc-swap = "1.0"
+arrow.workspace = true
+arrow-schema.workspace = true
 async-trait.workspace = true
 bincode = "1.3"
 catalog.workspace = true
@@ -34,6 +36,7 @@ datafusion.workspace = true
 datafusion-common.workspace = true
 datafusion-expr.workspace = true
 datafusion-functions-aggregate-common.workspace = true
+datafusion-physical-expr.workspace = true
 datatypes.workspace = true
 derive_more = { version = "1", default-features = false, features = ["display"] }
 geo = { version = "0.29", optional = true }
@@ -62,5 +65,7 @@ wkt = { version = "0.11", optional = true }
 
 [dev-dependencies]
 approx = "0.5"
+futures.workspace = true
+pretty_assertions = "1.4.0"
 serde = { version = "1.0", features = ["derive"] }
 tokio.workspace = true
diff --git a/src/common/function/src/aggrs.rs b/src/common/function/src/aggrs.rs
index e60de8101a..e618921a18 100644
--- a/src/common/function/src/aggrs.rs
+++ b/src/common/function/src/aggrs.rs
@@ -17,3 +17,5 @@ pub mod count_hash;
 #[cfg(feature = "geo")]
 pub mod geo;
 pub mod vector;
+
+pub mod aggr_wrapper;
diff --git a/src/common/function/src/aggrs/aggr_wrapper.rs b/src/common/function/src/aggrs/aggr_wrapper.rs
new file mode 100644
index 0000000000..0bfd52ed3b
--- /dev/null
+++ b/src/common/function/src/aggrs/aggr_wrapper.rs
@@ -0,0 +1,538 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Wrapper for making aggregate functions out of the state/merge functions of an original aggregate function.
+//!
+//! I.e., for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`.
+//!
+//! `foo_state`'s input args are the same as `foo`'s, and its output is a state object.
+//! Note that `foo_state` might have multiple output columns, so its output is a struct array
+//! in which each output column is a struct field.
+//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.
+
+use std::sync::Arc;
+
+use arrow::array::StructArray;
+use arrow_schema::Fields;
+use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
+use datafusion::optimizer::AnalyzerRule;
+use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
+use datafusion_common::{Column, ScalarValue};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::function::StateFieldsArgs;
+use datafusion_expr::{
+    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
+    Signature,
+};
+use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
+use datatypes::arrow::datatypes::{DataType, Field};
+
+/// Returns the name of the state function for the given aggregate function name.
+/// The state function is used to compute the state of the aggregate function.
+/// The state function's name is in the format `__<aggr_name>_state`.
+pub fn aggr_state_func_name(aggr_name: &str) -> String {
+    format!("__{}_state", aggr_name)
+}
+
+/// Returns the name of the merge function for the given aggregate function name.
+/// The merge function is used to merge the states of the state functions.
+/// The merge function's name is in the format `__<aggr_name>_merge`.
+pub fn aggr_merge_func_name(aggr_name: &str) -> String {
+    format!("__{}_merge", aggr_name)
+}
+
+/// A wrapper to make an aggregate function out of the state and merge functions of the original aggregate function.
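+///
+/// For example (see the tests below), `avg(number)` splits into a lower `__avg_state(number)`
+/// returning a struct of `avg[count]` and `avg[sum]`, and an upper `__avg_merge` that combines
+/// those partial structs and returns the final average.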
+/// The wrapper contains the original aggregate function, the state functions, and the merge function.
+///
+/// Note that a state function may have multiple output columns, so its return type is always a
+/// struct array, and the merge function is used to merge the states produced by the state function.
+#[derive(Debug, Clone)]
+pub struct StateMergeHelper;
+
+/// A struct to hold the two aggregate plans, one for the state function (lower) and one for the merge function (upper).
+#[allow(unused)]
+#[derive(Debug, Clone)]
+pub struct StepAggrPlan {
+    /// Upper merge plan, which is the aggregate plan that merges the states of the state function.
+    pub upper_merge: Arc<LogicalPlan>,
+    /// Lower state plan, which is the aggregate plan that computes the state of the aggregate function.
+    pub lower_state: Arc<LogicalPlan>,
+}
+
+pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
+    let mut expr_ref = expr;
+    while let Expr::Alias(alias) = expr_ref {
+        expr_ref = &alias.expr;
+    }
+    if let Expr::AggregateFunction(aggr_func) = expr_ref {
+        Some(aggr_func)
+    } else {
+        None
+    }
+}
+
+impl StateMergeHelper {
+    /// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
+    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
+        let aggr = {
+            // Certain aggregate functions need type coercion to work correctly, so analyze the plan first.
+            let aggr_plan = TypeCoercion::new().analyze(
+                LogicalPlan::Aggregate(aggr_plan),
+                &Default::default(),
+            )?;
+            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
+                aggr
+            } else {
+                return Err(datafusion_common::DataFusionError::Internal(format!(
+                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
+                    aggr_plan
+                )));
+            }
+        };
+        let mut lower_aggr_exprs = vec![];
+        let mut upper_aggr_exprs = vec![];
+
+        for aggr_expr in aggr.aggr_expr.iter() {
+            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
+                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
+                    "Unsupported aggregate expression for step aggr optimize: {:?}",
+                    aggr_expr
+                )));
+            };
+
+            let original_input_types = aggr_func
+                .args
+                .iter()
+                .map(|e| e.get_type(&aggr.input.schema()))
+                .collect::<Result<Vec<_>, _>>()?;
+
+            // first create the state function from the original aggregate function.
+            let state_func = StateWrapper::new((*aggr_func.func).clone())?;
+
+            let expr = AggregateFunction {
+                func: Arc::new(state_func.into()),
+                args: aggr_func.args.clone(),
+                distinct: aggr_func.distinct,
+                filter: aggr_func.filter.clone(),
+                order_by: aggr_func.order_by.clone(),
+                null_treatment: aggr_func.null_treatment,
+            };
+            let expr = Expr::AggregateFunction(expr);
+            let lower_state_output_col_name = expr.schema_name().to_string();
+
+            lower_aggr_exprs.push(expr);
+
+            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
+                aggr_expr,
+                aggr.input.schema(),
+                aggr.input.schema().as_arrow(),
+                &Default::default(),
+            )?;
+
+            let merge_func = MergeWrapper::new(
+                (*aggr_func.func).clone(),
+                original_phy_expr,
+                original_input_types,
+            )?;
+            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
+            let expr = AggregateFunction {
+                func: Arc::new(merge_func.into()),
+                args: vec![arg],
+                distinct: aggr_func.distinct,
+                filter: aggr_func.filter.clone(),
+                order_by: aggr_func.order_by.clone(),
+                null_treatment: aggr_func.null_treatment,
+            };
+
+            // alias to the original aggregate expr's schema name, so the parent plan can refer to it
+            // correctly.
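+            // e.g. for `sum(number)` the lower column is named `__sum_state(number)` and the
+            // upper expression is aliased back to `sum(number)` (see the tests below).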
+            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
+            upper_aggr_exprs.push(expr);
+        }
+
+        let mut lower = aggr.clone();
+        lower.aggr_expr = lower_aggr_exprs;
+        let lower_plan = LogicalPlan::Aggregate(lower);
+
+        // update aggregate's output schema
+        let lower_plan = Arc::new(lower_plan.recompute_schema()?);
+
+        let mut upper = aggr.clone();
+        let aggr_plan = LogicalPlan::Aggregate(aggr);
+        upper.aggr_expr = upper_aggr_exprs;
+        upper.input = lower_plan.clone();
+        // The upper plan's output schema should be the same as the original aggregate plan's output schema.
+        let upper_check = upper.clone();
+        let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
+        if *upper_plan.schema() != *aggr_plan.schema() {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]: {}\n[   original]: {}",
+                upper_plan.schema(), aggr_plan.schema()
+            )));
+        }
+
+        Ok(StepAggrPlan {
+            lower_state: lower_plan,
+            upper_merge: upper_plan,
+        })
+    }
+}
+
+/// Wrapper to make an aggregate function out of a state function.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct StateWrapper {
+    inner: AggregateUDF,
+    name: String,
+}
+
+impl StateWrapper {
+    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
+        let name = aggr_state_func_name(inner.name());
+        Ok(Self { inner, name })
+    }
+
+    pub fn inner(&self) -> &AggregateUDF {
+        &self.inner
+    }
+
+    /// Deduce the return type of the original aggregate function
+    /// based on the accumulator arguments.
+    pub fn deduce_aggr_return_type(
+        &self,
+        acc_args: &datafusion_expr::function::AccumulatorArgs,
+    ) -> datafusion_common::Result<DataType> {
+        let input_exprs = acc_args.exprs;
+        let input_schema = acc_args.schema;
+        let input_types = input_exprs
+            .iter()
+            .map(|e| e.data_type(input_schema))
+            .collect::<Result<Vec<_>, _>>()?;
+        let return_type = self.inner.return_type(&input_types)?;
+        Ok(return_type)
+    }
+}
+
+impl AggregateUDFImpl for StateWrapper {
+    fn accumulator<'a, 'b>(
+        &'a self,
+        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
+    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+        // Fix up and recover the proper accumulator args for the original aggregate function.
+        let state_type = acc_args.return_type.clone();
+        let inner = {
+            let old_return_type = self.deduce_aggr_return_type(&acc_args)?;
+            let acc_args = datafusion_expr::function::AccumulatorArgs {
+                return_type: &old_return_type,
+                schema: acc_args.schema,
+                ignore_nulls: acc_args.ignore_nulls,
+                ordering_req: acc_args.ordering_req,
+                is_reversed: acc_args.is_reversed,
+                name: acc_args.name,
+                is_distinct: acc_args.is_distinct,
+                exprs: acc_args.exprs,
+            };
+            self.inner.accumulator(acc_args)?
+        };
+        Ok(Box::new(StateAccum::new(inner, state_type)?))
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.name.as_str()
+    }
+
+    fn is_nullable(&self) -> bool {
+        self.inner.is_nullable()
+    }
+
+    /// Return the state fields as the output struct type.
+    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
+        let old_return_type = self.inner.return_type(arg_types)?;
+        let state_fields_args = StateFieldsArgs {
+            name: self.inner().name(),
+            input_types: arg_types,
+            return_type: &old_return_type,
+            // TODO(discord9): how to get this? Probably ok to leave empty.
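+            // Assumption: none of the aggregates wrapped here are order-sensitive, so empty
+            // ordering fields should be acceptable.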
+            ordering_fields: &[],
+            is_distinct: false,
+        };
+        let state_fields = self.inner.state_fields(state_fields_args)?;
+        let struct_field = DataType::Struct(state_fields.into());
+        Ok(struct_field)
+    }
+
+    /// The state function's output fields are the same as the original aggregate function's state fields.
+    fn state_fields(
+        &self,
+        args: datafusion_expr::function::StateFieldsArgs,
+    ) -> datafusion_common::Result<Vec<Field>> {
+        let old_return_type = self.inner.return_type(args.input_types)?;
+        let state_fields_args = StateFieldsArgs {
+            name: args.name,
+            input_types: args.input_types,
+            return_type: &old_return_type,
+            ordering_fields: args.ordering_fields,
+            is_distinct: args.is_distinct,
+        };
+        self.inner.state_fields(state_fields_args)
+    }
+
+    /// The state function's signature is the same as the original aggregate function's signature.
+    fn signature(&self) -> &Signature {
+        self.inner.signature()
+    }
+
+    /// Type coercion is delegated to the inner function; the optimizer should already be able to
+    /// produce the expected input types.
+    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
+        self.inner.coerce_types(arg_types)
+    }
+}
+
+/// The wrapper's input is the same as the original aggregate function's input,
+/// and the output is the state function's output.
+#[derive(Debug)]
+pub struct StateAccum {
+    inner: Box<dyn Accumulator>,
+    state_fields: Fields,
+}
+
+impl StateAccum {
+    pub fn new(
+        inner: Box<dyn Accumulator>,
+        state_type: DataType,
+    ) -> datafusion_common::Result<Self> {
+        let DataType::Struct(fields) = state_type else {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Expected a struct type for state, got: {:?}",
+                state_type
+            )));
+        };
+        Ok(Self {
+            inner,
+            state_fields: fields,
+        })
+    }
+}
+
+impl Accumulator for StateAccum {
+    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+        let state = self.inner.state()?;
+
+        let array = state
+            .iter()
+            .map(|s| s.to_array())
+            .collect::<Result<Vec<_>, _>>()?;
+        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
+        Ok(ScalarValue::Struct(Arc::new(struct_array)))
+    }
+
+    fn merge_batch(
+        &mut self,
+        states: &[datatypes::arrow::array::ArrayRef],
+    ) -> datafusion_common::Result<()> {
+        self.inner.merge_batch(states)
+    }
+
+    fn update_batch(
+        &mut self,
+        values: &[datatypes::arrow::array::ArrayRef],
+    ) -> datafusion_common::Result<()> {
+        self.inner.update_batch(values)
+    }
+
+    fn size(&self) -> usize {
+        self.inner.size()
+    }
+
+    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
+        self.inner.state()
+    }
+}
+
+/// TODO(discord9): mark this function as non-ser/de-able.
+///
+/// This wrapper shouldn't be registered as a UDAF, as it contains extra data that is not
+/// serializable and changes for different logical plans.
+#[derive(Debug, Clone)]
+pub struct MergeWrapper {
+    inner: AggregateUDF,
+    name: String,
+    merge_signature: Signature,
+    /// The original physical expression of the aggregate function; we can't store the original
+    /// aggregate function directly, as `PhysicalExpr` doesn't implement `Any`.
+    original_phy_expr: Arc<AggregateFunctionExpr>,
+    original_input_types: Vec<DataType>,
+}
+
+impl MergeWrapper {
+    pub fn new(
+        inner: AggregateUDF,
+        original_phy_expr: Arc<AggregateFunctionExpr>,
+        original_input_types: Vec<DataType>,
+    ) -> datafusion_common::Result<Self> {
+        let name = aggr_merge_func_name(inner.name());
+        // The input type is actually a struct type, wrapping the state fields of the original aggregate function.
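+        // e.g. for `avg` the single argument is a struct of `avg[count]: UInt64` and
+        // `avg[sum]: Float64` (see the tests below).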
+        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);
+
+        Ok(Self {
+            inner,
+            name,
+            merge_signature,
+            original_phy_expr,
+            original_input_types,
+        })
+    }
+
+    pub fn inner(&self) -> &AggregateUDF {
+        &self.inner
+    }
+}
+
+impl AggregateUDFImpl for MergeWrapper {
+    fn accumulator<'a, 'b>(
+        &'a self,
+        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
+    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+        if acc_args.schema.fields().len() != 1
+            || !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
+        {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Expected one struct type as input, got: {:?}",
+                acc_args.schema
+            )));
+        }
+        let input_type = acc_args.schema.field(0).data_type();
+        let DataType::Struct(fields) = input_type else {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Expected a struct type for input, got: {:?}",
+                input_type
+            )));
+        };
+
+        let inner_accum = self.original_phy_expr.create_accumulator()?;
+        Ok(Box::new(MergeAccum::new(inner_accum, fields)))
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.name.as_str()
+    }
+
+    fn is_nullable(&self) -> bool {
+        self.inner.is_nullable()
+    }
+
+    /// Notice that `arg_types` here are actually the `state_fields`' data types,
+    /// so return a fixed return type instead of using `arg_types` to determine it.
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
+        // The return type is the same as the original aggregate function's return type.
+        let ret_type = self.inner.return_type(&self.original_input_types)?;
+        Ok(ret_type)
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.merge_signature
+    }
+
+    /// Type coercion does no conversion here, as the optimizer should already produce struct types.
+    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
+        // Just check that there is exactly one arg type and that it is a struct type.
+        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Expected one struct type as input, got: {:?}",
+                arg_types
+            )));
+        }
+        Ok(arg_types.to_vec())
+    }
+
+    /// Just return the original aggregate function's state fields.
+    fn state_fields(
+        &self,
+        _args: datafusion_expr::function::StateFieldsArgs,
+    ) -> datafusion_common::Result<Vec<Field>> {
+        self.original_phy_expr.state_fields()
+    }
+}
+
+/// The merge accumulator, which modifies `update_batch` to accept a single struct array that
+/// includes the state fields of the original aggregate function, and merges those states into
+/// the original accumulator. The output is the same as the original aggregate function's.
+#[derive(Debug)]
+pub struct MergeAccum {
+    inner: Box<dyn Accumulator>,
+    state_fields: Fields,
+}
+
+impl MergeAccum {
+    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
+        Self {
+            inner,
+            state_fields: state_fields.clone(),
+        }
+    }
+}
+
+impl Accumulator for MergeAccum {
+    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+        self.inner.evaluate()
+    }
+
+    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
+        self.inner.merge_batch(states)
+    }
+
+    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
+        let value = values.first().ok_or_else(|| {
+            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
+        })?;
+        // The input values are states from other accumulators, so we merge them.
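+        // i.e. unpack the struct array into its field columns and feed them to the inner
+        // accumulator as if they came straight from another accumulator's `state()`.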
+        let struct_arr = value
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .ok_or_else(|| {
+                datafusion_common::DataFusionError::Internal(format!(
+                    "Expected StructArray, got: {:?}",
+                    value.data_type()
+                ))
+            })?;
+        let fields = struct_arr.fields();
+        if fields != &self.state_fields {
+            return Err(datafusion_common::DataFusionError::Internal(format!(
+                "Expected state fields: {:?}, got: {:?}",
+                self.state_fields, fields
+            )));
+        }
+
+        // Now the fields are known to match, so we can merge the batch by passing the columns
+        // through; the column order is the same.
+        let state_columns = struct_arr.columns();
+        self.inner.merge_batch(state_columns)
+    }
+
+    fn size(&self) -> usize {
+        self.inner.size()
+    }
+
+    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
+        self.inner.state()
+    }
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/src/common/function/src/aggrs/aggr_wrapper/tests.rs b/src/common/function/src/aggrs/aggr_wrapper/tests.rs
new file mode 100644
index 0000000000..23a516a619
--- /dev/null
+++ b/src/common/function/src/aggrs/aggr_wrapper/tests.rs
@@ -0,0 +1,804 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+use std::pin::Pin;
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll};
+
+use arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
+use arrow::record_batch::RecordBatch;
+use arrow_schema::SchemaRef;
+use datafusion::catalog::{Session, TableProvider};
+use datafusion::datasource::DefaultTableSource;
+use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
+use datafusion::functions_aggregate::average::avg_udaf;
+use datafusion::functions_aggregate::sum::sum_udaf;
+use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
+use datafusion::optimizer::AnalyzerRule;
+use datafusion::physical_plan::aggregates::AggregateExec;
+use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
+use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
+use datafusion::prelude::SessionContext;
+use datafusion_common::{Column, TableReference};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::sqlparser::ast::NullTreatment;
+use datafusion_expr::{Aggregate, Expr, LogicalPlan, SortExpr, TableScan};
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use datatypes::arrow_array::StringArray;
+use futures::{Stream, StreamExt as _};
+use pretty_assertions::assert_eq;
+
+use super::*;
+use crate::aggrs::approximate::hll::HllState;
+use crate::aggrs::approximate::uddsketch::UddSketchState;
+use crate::aggrs::count_hash::CountHash;
+use crate::function::Function as _;
+use crate::scalars::hll_count::HllCalcFunction;
+use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
+
+#[derive(Debug)]
+pub struct MockInputExec {
+    input: Vec<RecordBatch>,
+    schema: SchemaRef,
+    properties: PlanProperties,
+}
+
+impl MockInputExec {
+    pub fn new(input: Vec<RecordBatch>, schema: SchemaRef) -> Self {
+        Self {
+            properties: PlanProperties::new(
+                EquivalenceProperties::new(schema.clone()),
+                Partitioning::UnknownPartitioning(1),
+                EmissionType::Incremental,
+                Boundedness::Bounded,
+            ),
+            input,
+            schema,
+        }
+    }
+}
+
+impl DisplayAs for MockInputExec {
+    fn fmt_as(&self, _t: DisplayFormatType, _f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        unimplemented!()
+    }
+}
+
+impl ExecutionPlan for MockInputExec {
+    fn name(&self) -> &str {
+        "MockInputExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> datafusion_common::Result<SendableRecordBatchStream> {
+        let stream = MockStream {
+            stream: self.input.clone(),
+            schema: self.schema.clone(),
+            idx: 0,
+        };
+        Ok(Box::pin(stream))
+    }
+}
+
+struct MockStream {
+    stream: Vec<RecordBatch>,
+    schema: SchemaRef,
+    idx: usize,
+}
+
+impl Stream for MockStream {
+    type Item = datafusion_common::Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if self.idx < self.stream.len() {
+            let ret = self.stream[self.idx].clone();
+            self.idx += 1;
+            Poll::Ready(Some(Ok(ret)))
+        } else {
+            Poll::Ready(None)
+        }
+    }
+}
+
+impl RecordBatchStream for MockStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+#[derive(Debug)]
+struct DummyTableProvider {
+    schema: Arc<arrow_schema::Schema>,
+    record_batch: Mutex<Option<RecordBatch>>,
+}
+
+impl DummyTableProvider {
+    #[allow(unused)]
+    pub fn new(schema: Arc<arrow_schema::Schema>, record_batch: Option<RecordBatch>) -> Self {
+        Self {
+            schema,
+            record_batch: Mutex::new(record_batch),
+        }
+    }
+}
+
+impl Default for DummyTableProvider {
+    fn default() -> Self {
+        Self {
+            schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "number",
+                DataType::Int64,
+                true,
+            )])),
+            record_batch: Mutex::new(None),
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl TableProvider for DummyTableProvider {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn schema(&self) -> Arc<arrow_schema::Schema> {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> datafusion_expr::TableType {
+        datafusion_expr::TableType::Base
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        _projection: Option<&Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
+        let input: Vec<RecordBatch> = self
+            .record_batch
+            .lock()
+            .unwrap()
+            .clone()
+            .map(|r| vec![r])
+            .unwrap_or_default();
+        Ok(Arc::new(MockInputExec::new(input, self.schema.clone())))
+    }
+}
+
+fn dummy_table_scan() -> LogicalPlan {
+    let table_provider = Arc::new(DummyTableProvider::default());
+    let table_source = DefaultTableSource::new(table_provider);
+    LogicalPlan::TableScan(
+        TableScan::try_new(
+            TableReference::bare("Number"),
+            Arc::new(table_source),
+            None,
+            vec![],
+            None,
+        )
+        .unwrap(),
+    )
+}
+
+#[tokio::test]
+async fn test_sum_udaf() {
+    let ctx = SessionContext::new();
+
+    let sum = datafusion::functions_aggregate::sum::sum_udaf();
+    let sum = (*sum).clone();
+    let original_aggr = Aggregate::try_new(
+        Arc::new(dummy_table_scan()),
+        vec![],
+        vec![Expr::AggregateFunction(AggregateFunction::new_udf(
+            Arc::new(sum.clone()),
+            vec![Expr::Column(Column::new_unqualified("number"))],
+            false,
+            None,
+            None,
+            None,
+        ))],
+    )
+    .unwrap();
+    let res = StateMergeHelper::split_aggr_node(original_aggr).unwrap();
+
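+    // The split should produce a lower plan computing `__sum_state(number)` and an upper plan
+    // merging those states, aliased back to `sum(number)`; both are checked structurally below.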
+    let expected_lower_plan = LogicalPlan::Aggregate(
+        Aggregate::try_new(
+            Arc::new(dummy_table_scan()),
+            vec![],
+            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
+                Arc::new(StateWrapper::new(sum.clone()).unwrap().into()),
+                vec![Expr::Column(Column::new_unqualified("number"))],
+                false,
+                None,
+                None,
+                None,
+            ))],
+        )
+        .unwrap(),
+    )
+    .recompute_schema()
+    .unwrap();
+    assert_eq!(res.lower_state.as_ref(), &expected_lower_plan);
+
+    let expected_merge_plan = LogicalPlan::Aggregate(
+        Aggregate::try_new(
+            Arc::new(expected_lower_plan),
+            vec![],
+            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
+                Arc::new(
+                    MergeWrapper::new(
+                        sum.clone(),
+                        Arc::new(
+                            AggregateExprBuilder::new(
+                                Arc::new(sum.clone()),
+                                vec![Arc::new(
+                                    datafusion::physical_expr::expressions::Column::new(
+                                        "number", 0,
+                                    ),
+                                )],
+                            )
+                            .schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
+                            .alias("sum(number)")
+                            .build()
+                            .unwrap(),
+                        ),
+                        vec![DataType::Int64],
+                    )
+                    .unwrap()
+                    .into(),
+                ),
+                vec![Expr::Column(Column::new_unqualified("__sum_state(number)"))],
+                false,
+                None,
+                None,
+                None,
+            ))
+            .alias("sum(number)")],
+        )
+        .unwrap(),
+    );
+    assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
+
+    let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
+        .create_physical_plan(&res.lower_state, &ctx.state())
+        .await
+        .unwrap();
+    let aggr_exec = phy_aggr_state_plan
+        .as_any()
+        .downcast_ref::<AggregateExec>()
+        .unwrap();
+    let aggr_func_expr = &aggr_exec.aggr_expr()[0];
+    let mut state_accum = aggr_func_expr.create_accumulator().unwrap();
+
+    // evaluate the state function
+    let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
+    let values = vec![Arc::new(input) as arrow::array::ArrayRef];
+
+    state_accum.update_batch(&values).unwrap();
+    let state = state_accum.state().unwrap();
+    assert_eq!(state.len(), 1);
+    assert_eq!(state[0], ScalarValue::Int64(Some(6)));
+
+    let eval_res = state_accum.evaluate().unwrap();
+    assert_eq!(
+        eval_res,
+        ScalarValue::Struct(Arc::new(
+            StructArray::try_new(
+                vec![Field::new("sum[sum]", DataType::Int64, true)].into(),
+                vec![Arc::new(Int64Array::from(vec![Some(6)]))],
+                None,
+            )
+            .unwrap(),
+        ))
+    );
+
+    let phy_aggr_merge_plan = DefaultPhysicalPlanner::default()
+        .create_physical_plan(&res.upper_merge, &ctx.state())
+        .await
+        .unwrap();
+    let aggr_exec = phy_aggr_merge_plan
+        .as_any()
+        .downcast_ref::<AggregateExec>()
+        .unwrap();
+    let aggr_func_expr = &aggr_exec.aggr_expr()[0];
+    let mut merge_accum = aggr_func_expr.create_accumulator().unwrap();
+
+    let merge_input =
+        vec![Arc::new(Int64Array::from(vec![Some(6), Some(42), None])) as arrow::array::ArrayRef];
+    let merge_input_struct_arr = StructArray::try_new(
+        vec![Field::new("sum[sum]", DataType::Int64, true)].into(),
+        merge_input,
+        None,
+    )
+    .unwrap();
+
+    merge_accum
+        .update_batch(&[Arc::new(merge_input_struct_arr)])
+        .unwrap();
+    let merge_state = merge_accum.state().unwrap();
+    assert_eq!(merge_state.len(), 1);
+    assert_eq!(merge_state[0], ScalarValue::Int64(Some(48)));
+
+    let merge_eval_res = merge_accum.evaluate().unwrap();
+    assert_eq!(merge_eval_res, ScalarValue::Int64(Some(48)));
+}
+
+#[tokio::test]
+async fn test_avg_udaf() {
+    let ctx = SessionContext::new();
+
+    let avg = datafusion::functions_aggregate::average::avg_udaf();
+    let avg = (*avg).clone();
+
+    let original_aggr = Aggregate::try_new(
+        Arc::new(dummy_table_scan()),
+        vec![],
+        vec![Expr::AggregateFunction(AggregateFunction::new_udf(
+            Arc::new(avg.clone()),
+            vec![Expr::Column(Column::new_unqualified("number"))],
+            false,
+            None,
+            None,
+            None,
+        ))],
+    )
+    .unwrap();
+    let res = StateMergeHelper::split_aggr_node(original_aggr).unwrap();
+
+    let state_func: Arc<AggregateUDF> = Arc::new(StateWrapper::new(avg.clone()).unwrap().into());
+    let expected_aggr_state_plan = LogicalPlan::Aggregate(
+        Aggregate::try_new(
+            Arc::new(dummy_table_scan()),
+            vec![],
+            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
+                state_func,
+                vec![Expr::Column(Column::new_unqualified("number"))],
+                false,
+                None,
+                None,
+                None,
+            ))],
+        )
+        .unwrap(),
+    );
+    // type-coerced so the avg aggregate function works correctly
+    let coerced_aggr_state_plan = TypeCoercion::new()
+        .analyze(expected_aggr_state_plan.clone(), &Default::default())
+        .unwrap();
+    assert_eq!(res.lower_state.as_ref(), &coerced_aggr_state_plan);
+    assert_eq!(
+        res.lower_state.schema().as_arrow(),
+        &arrow_schema::Schema::new(vec![Field::new(
+            "__avg_state(number)",
+            DataType::Struct(
+                vec![
+                    Field::new("avg[count]", DataType::UInt64, true),
+                    Field::new("avg[sum]", DataType::Float64, true)
+                ]
+                .into()
+            ),
+            true,
+        )])
+    );
+
+    let expected_merge_fn = MergeWrapper::new(
+        avg.clone(),
+        Arc::new(
+            AggregateExprBuilder::new(
+                Arc::new(avg.clone()),
+                vec![Arc::new(
+                    datafusion::physical_expr::expressions::Column::new("number", 0),
+                )],
+            )
+            .schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
+            .alias("avg(number)")
+            .build()
+            .unwrap(),
+        ),
+        // coerced to float64
+        vec![DataType::Float64],
+    )
+    .unwrap();
+
+    let expected_merge_plan = LogicalPlan::Aggregate(
+        Aggregate::try_new(
+            Arc::new(coerced_aggr_state_plan.clone()),
+            vec![],
+            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
+                Arc::new(expected_merge_fn.into()),
+                vec![Expr::Column(Column::new_unqualified("__avg_state(number)"))],
+                false,
+                None,
+                None,
+                None,
+            ))
+            .alias("avg(number)")],
+        )
+        .unwrap(),
+    );
+    assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);
+
+    let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
+        .create_physical_plan(&coerced_aggr_state_plan, &ctx.state())
+        .await
+        .unwrap();
+    let aggr_exec = phy_aggr_state_plan
+        .as_any()
+        .downcast_ref::<AggregateExec>()
+        .unwrap();
+    let aggr_func_expr = &aggr_exec.aggr_expr()[0];
+    let mut state_accum = aggr_func_expr.create_accumulator().unwrap();
+
+    // evaluate the state function
+    let input = Float64Array::from(vec![Some(1.), Some(2.), None, Some(3.)]);
+    let values = vec![Arc::new(input) as arrow::array::ArrayRef];
+
+    state_accum.update_batch(&values).unwrap();
+    let state = state_accum.state().unwrap();
+    assert_eq!(state.len(), 2);
+    assert_eq!(state[0], ScalarValue::UInt64(Some(3)));
+    assert_eq!(state[1], ScalarValue::Float64(Some(6.)));
+
+    let eval_res = state_accum.evaluate().unwrap();
+    let expected = Arc::new(
+        StructArray::try_new(
+            vec![
+                Field::new("avg[count]", DataType::UInt64, true),
+                Field::new("avg[sum]", DataType::Float64, true),
+            ]
+            .into(),
+            vec![
+                Arc::new(UInt64Array::from(vec![Some(3)])),
+                Arc::new(Float64Array::from(vec![Some(6.)])),
+            ],
+            None,
+        )
+        .unwrap(),
+    );
+    assert_eq!(eval_res, ScalarValue::Struct(expected));
+
+    let phy_aggr_merge_plan = DefaultPhysicalPlanner::default()
+        .create_physical_plan(&res.upper_merge, &ctx.state())
+        .await
+        .unwrap();
+    let aggr_exec = phy_aggr_merge_plan
+        .as_any()
+        .downcast_ref::<AggregateExec>()
+        .unwrap();
+    let aggr_func_expr = &aggr_exec.aggr_expr()[0];
+
+    let mut merge_accum = aggr_func_expr.create_accumulator().unwrap();
+
+    let merge_input = vec![
+        Arc::new(UInt64Array::from(vec![Some(3), Some(42), None])) as arrow::array::ArrayRef,
+        Arc::new(Float64Array::from(vec![Some(48.), Some(84.), None])),
+    ];
+    let merge_input_struct_arr = StructArray::try_new(
+        vec![
+            Field::new("avg[count]", DataType::UInt64, true),
+            Field::new("avg[sum]", DataType::Float64, true),
+        ]
+        .into(),
+        merge_input,
+        None,
+    )
+    .unwrap();
+
+    merge_accum
+        .update_batch(&[Arc::new(merge_input_struct_arr)])
+        .unwrap();
+    let merge_state = merge_accum.state().unwrap();
+    assert_eq!(merge_state.len(), 2);
+    assert_eq!(merge_state[0], ScalarValue::UInt64(Some(45)));
+    assert_eq!(merge_state[1], ScalarValue::Float64(Some(132.)));
+
+    let merge_eval_res = merge_accum.evaluate().unwrap();
+    // the merge function returns the average, which is 132 / 45
+    assert_eq!(merge_eval_res, ScalarValue::Float64(Some(132. / 45_f64)));
+}
+
+/// Tests whether the UDAF state fields are correctly implemented,
+/// especially for our own custom UDAFs' state fields,
+/// by comparing eval results before and after the split into state/merge functions.
+#[tokio::test]
+async fn test_udaf_correct_eval_result() {
+    struct TestCase {
+        func: Arc<AggregateUDF>,
+        args: Vec<Expr>,
+        input_schema: SchemaRef,
+        input: Vec<ArrayRef>,
+        expected_output: Option<ScalarValue>,
+        expected_fn: Option<ExpectedFn>,
+        distinct: bool,
+        filter: Option<Box<Expr>>,
+        order_by: Option<Vec<SortExpr>>,
+        null_treatment: Option<NullTreatment>,
+    }
+    type ExpectedFn = fn(ArrayRef) -> bool;
+
+    let test_cases = vec![
+        TestCase {
+            func: sum_udaf(),
+            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "number",
+                DataType::Int64,
+                true,
+            )])),
+            args: vec![Expr::Column(Column::new_unqualified("number"))],
+            input: vec![Arc::new(Int64Array::from(vec![
+                Some(1),
+                Some(2),
+                None,
+                Some(3),
+            ]))],
+            expected_output: Some(ScalarValue::Int64(Some(6))),
+            expected_fn: None,
+            distinct: false,
+            filter: None,
+            order_by: None,
+            null_treatment: None,
+        },
+        TestCase {
+            func: avg_udaf(),
+            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "number",
+                DataType::Int64,
+                true,
+            )])),
+            args: vec![Expr::Column(Column::new_unqualified("number"))],
+            input: vec![Arc::new(Int64Array::from(vec![
+                Some(1),
+                Some(2),
+                None,
+                Some(3),
+            ]))],
+            expected_output: Some(ScalarValue::Float64(Some(2.0))),
+            expected_fn: None,
+            distinct: false,
+            filter: None,
+            order_by: None,
+            null_treatment: None,
+        },
+        TestCase {
+            func: Arc::new(CountHash::udf_impl()),
+            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "number",
+                DataType::Int64,
+                true,
+            )])),
+            args: vec![Expr::Column(Column::new_unqualified("number"))],
+            input: vec![Arc::new(Int64Array::from(vec![
+                Some(1),
+                Some(2),
+                None,
+                Some(3),
+                Some(3),
+                Some(3),
+            ]))],
+            expected_output: Some(ScalarValue::Int64(Some(4))),
+            expected_fn: None,
+            distinct: false,
+            filter: None,
+            order_by: None,
+            null_treatment: None,
+        },
+        TestCase {
+            func: Arc::new(UddSketchState::state_udf_impl()),
+            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "number",
+                DataType::Float64,
+                true,
+            )])),
+            args: vec![
+                Expr::Literal(ScalarValue::Int64(Some(128))),
+                Expr::Literal(ScalarValue::Float64(Some(0.05))),
+                Expr::Column(Column::new_unqualified("number")),
+            ],
+            input: vec![Arc::new(Float64Array::from(vec![
+                Some(1.),
+                Some(2.),
+                None,
+                Some(3.),
+                Some(3.),
+                Some(3.),
+            ]))],
+            expected_output: None,
+            expected_fn: Some(|arr| {
+                let percent = ScalarValue::Float64(Some(0.5)).to_array().unwrap();
+                let percent = datatypes::vectors::Helper::try_into_vector(percent).unwrap();
+                let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
+                let udd_calc = UddSketchCalcFunction;
+                let res = udd_calc
+                    .eval(&Default::default(), &[percent, state])
+                    .unwrap();
+                let binding = res.to_arrow_array();
+                let res_arr = binding.as_any().downcast_ref::<Float64Array>().unwrap();
+                assert!(res_arr.len() == 1);
+                assert!((res_arr.value(0) - 2.856578984907706f64).abs() <= f64::EPSILON);
+                true
+            }),
+            distinct: false,
+            filter: None,
+            order_by: None,
+            null_treatment: None,
+        },
+        TestCase {
+            func: Arc::new(HllState::state_udf_impl()),
+            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "word",
+                DataType::Utf8,
+                true,
+            )])),
+            args: vec![Expr::Column(Column::new_unqualified("word"))],
+            input: vec![Arc::new(StringArray::from(vec![
+                Some("foo"),
+                Some("bar"),
+                None,
+                Some("baz"),
+                Some("baz"),
+            ]))],
+            expected_output: None,
+            expected_fn: Some(|arr| {
+                let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
+                let hll_calc = HllCalcFunction;
+                let res = hll_calc.eval(&Default::default(), &[state]).unwrap();
+                let binding = res.to_arrow_array();
+                let res_arr = binding.as_any().downcast_ref::<UInt64Array>().unwrap();
+                assert!(res_arr.len() == 1);
+                assert_eq!(res_arr.value(0), 3);
+                true
+            }),
+            distinct: false,
+            filter: None,
+            order_by: None,
+            null_treatment: None,
+        },
+        // TODO(discord9): udd_merge/hll_merge/geo_path/quantile_aggr tests
+    ];
+    let test_table_ref = TableReference::bare("TestTable");
+
+    for case in test_cases {
+        let ctx = SessionContext::new();
+        let table_provider = DummyTableProvider::new(
+            case.input_schema.clone(),
+            Some(RecordBatch::try_new(case.input_schema.clone(), case.input.clone()).unwrap()),
+        );
+        let table_source = DefaultTableSource::new(Arc::new(table_provider));
+        let logical_plan = LogicalPlan::TableScan(
+            TableScan::try_new(
+                test_table_ref.clone(),
+                Arc::new(table_source),
+                None,
+                vec![],
+                None,
+            )
+            .unwrap(),
+        );
+
+        let args = case.args;
+
+        let aggr_expr = Expr::AggregateFunction(AggregateFunction::new_udf(
+            case.func.clone(),
+            args,
+            case.distinct,
+            case.filter,
+            case.order_by,
+            case.null_treatment,
+        ));
+
+        let aggr_plan = LogicalPlan::Aggregate(
+            Aggregate::try_new(Arc::new(logical_plan), vec![], vec![aggr_expr]).unwrap(),
+        );
+
+        // make sure the aggr_plan is type coerced
+        let aggr_plan = TypeCoercion::new()
+            .analyze(aggr_plan, &Default::default())
+            .unwrap();
+
+        // first eval the original aggregate function
+        let phy_full_aggr_plan = DefaultPhysicalPlanner::default()
+            .create_physical_plan(&aggr_plan, &ctx.state())
+            .await
+            .unwrap();
+
+        {
+            let unsplit_result = execute_phy_plan(&phy_full_aggr_plan).await.unwrap();
+            assert_eq!(unsplit_result.len(), 1);
+            let unsplit_batch = &unsplit_result[0];
+            assert_eq!(unsplit_batch.num_columns(), 1);
+            assert_eq!(unsplit_batch.num_rows(), 1);
+            let unsplit_col = unsplit_batch.column(0);
+            if let Some(expected_output) = &case.expected_output {
+                assert_eq!(unsplit_col.data_type(), &expected_output.data_type());
+                assert_eq!(unsplit_col.len(), 1);
+                assert_eq!(unsplit_col, &expected_output.to_array().unwrap());
+            }
+
+            if let Some(expected_fn) = &case.expected_fn {
+                assert!(expected_fn(unsplit_col.clone()));
+            }
+        }
+        let LogicalPlan::Aggregate(aggr_plan) = aggr_plan else {
+            panic!("Expected Aggregate plan");
+        };
+        let split_plan = StateMergeHelper::split_aggr_node(aggr_plan).unwrap();
+
+        let phy_upper_plan = DefaultPhysicalPlanner::default()
+            .create_physical_plan(&split_plan.upper_merge, &ctx.state())
+            .await
+            .unwrap();
+
+        // Since the upper plan uses the lower plan as input, executing the upper plan also
+        // executes the lower plan, which should give the same result as the original
+        // aggregate function.
+        {
+            let split_res = execute_phy_plan(&phy_upper_plan).await.unwrap();
+
+            assert_eq!(split_res.len(), 1);
+            let split_batch = &split_res[0];
+            assert_eq!(split_batch.num_columns(), 1);
+            assert_eq!(split_batch.num_rows(), 1);
+            let split_col = split_batch.column(0);
+            if let Some(expected_output) = &case.expected_output {
+                assert_eq!(split_col.data_type(), &expected_output.data_type());
+                assert_eq!(split_col.len(), 1);
+                assert_eq!(split_col, &expected_output.to_array().unwrap());
+            }
+
+            if let Some(expected_fn) = &case.expected_fn {
+                assert!(expected_fn(split_col.clone()));
+            }
+        }
+    }
+}
+
+async fn execute_phy_plan(
+    phy_plan: &Arc<dyn ExecutionPlan>,
+) -> datafusion_common::Result<Vec<RecordBatch>> {
+    let task_ctx = Arc::new(TaskContext::default());
+    let mut stream = phy_plan.execute(0, task_ctx)?;
+    let mut batches = Vec::new();
+    while let Some(batch) = stream.next().await {
+        batches.push(batch?);
+    }
+    Ok(batches)
+}
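
A minimal end-to-end usage sketch, not part of the diff itself: it assumes it runs inside the test module above, reusing the `dummy_table_scan()` fixture and the `execute_phy_plan()` helper; every other name is taken verbatim from this patch.

async fn split_and_run_sum(ctx: &SessionContext) -> datafusion_common::Result<Vec<RecordBatch>> {
    // Build `sum(number)` over the dummy scan, exactly as test_sum_udaf does.
    let aggr = Aggregate::try_new(
        Arc::new(dummy_table_scan()),
        vec![], // no GROUP BY keys
        vec![Expr::AggregateFunction(AggregateFunction::new_udf(
            Arc::new((*sum_udaf()).clone()),
            vec![Expr::Column(Column::new_unqualified("number"))],
            false, // distinct
            None,  // filter
            None,  // order_by
            None,  // null_treatment
        ))],
    )?;

    // Split into the lower `__sum_state(number)` plan and the upper `__sum_merge` plan.
    let step = StateMergeHelper::split_aggr_node(aggr)?;

    // Executing the upper plan also drives the lower plan it wraps, yielding the same
    // result as running the original `sum(number)` aggregation directly.
    let phy_plan = DefaultPhysicalPlanner::default()
        .create_physical_plan(&step.upper_merge, &ctx.state())
        .await?;
    execute_phy_plan(&phy_plan).await
}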