Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2025-12-26 08:00:01 +00:00
feat: state/merge wrapper for aggr func (#6377)
* refactor: move to query crate
* refactor: split to multiple columns
* feat: aggr merge accum wrapper
* rename shorter
* feat: add all in one helper
* tests: sum&avg
* chore: allow unused
* chore: typos
* refactor: per ds
* chore: fix tests
* refactor: move to common-function
* WIP massive refactor
* typo
* todo: stuff
* refactor: state2input type
* chore: rm unused
* refactor: per bot review
* chore: per bot
* refactor: rm duplicate infer type
* chore: better test
* fix: test sum refactor&fix wrong state types
* test: refactor avg udaf test
* refactor: split files
* refactor: docs&dedup
* refactor: allow merge to carry extra info
* chore: rm unused
* chore: clippy
* chore: docs&unused
* refactor: check fields equal
* test: test count_hash
* test: more custom udaf
* chore: clippy
* refactor: per review

---------

Signed-off-by: discord9 <discord9@163.com>
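Before the diff itself, a condensed sketch of what the new API does, distilled from `test_sum_udaf` in the tests file below (`dummy_table_scan()` is a helper defined there, and imports are as in that file; this is an illustration, not extra code from the commit):

// Build a plain `sum(number)` aggregate over a dummy scan, then split it.
let sum = (*datafusion::functions_aggregate::sum::sum_udaf()).clone();
let original_aggr = Aggregate::try_new(
    Arc::new(dummy_table_scan()),
    vec![],
    vec![Expr::AggregateFunction(AggregateFunction::new_udf(
        Arc::new(sum),
        vec![Expr::Column(Column::new_unqualified("number"))],
        false,
        None,
        None,
        None,
    ))],
)
.unwrap();
// One aggregate node becomes two: `lower_state` computes `__sum_state(number)`
// (a struct-typed column holding the accumulator state), and `upper_merge`
// consumes that column and yields `sum(number)` with the original schema.
let StepAggrPlan { lower_state, upper_merge } =
    StateMergeHelper::split_aggr_node(original_aggr).unwrap();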
Cargo.lock (generated, 5 changed lines)
@@ -2347,6 +2347,8 @@ dependencies = [
 "api",
 "approx 0.5.1",
 "arc-swap",
 "arrow 54.2.1",
 "arrow-schema 54.3.1",
 "async-trait",
 "bincode",
 "catalog",
@@ -2365,8 +2367,10 @@ dependencies = [
 "datafusion-common",
 "datafusion-expr",
 "datafusion-functions-aggregate-common",
 "datafusion-physical-expr",
 "datatypes",
 "derive_more",
 "futures",
 "geo",
 "geo-types",
 "geohash",
@@ -2379,6 +2383,7 @@ dependencies = [
 "num-traits",
 "once_cell",
 "paste",
 "pretty_assertions",
 "s2",
 "serde",
 "serde_json",
src/common/function/Cargo.toml
@@ -16,6 +16,8 @@ geo = ["geohash", "h3o", "s2", "wkt", "geo-types", "dep:geo"]
ahash.workspace = true
api.workspace = true
arc-swap = "1.0"
arrow.workspace = true
arrow-schema.workspace = true
async-trait.workspace = true
bincode = "1.3"
catalog.workspace = true
@@ -34,6 +36,7 @@ datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-functions-aggregate-common.workspace = true
datafusion-physical-expr.workspace = true
datatypes.workspace = true
derive_more = { version = "1", default-features = false, features = ["display"] }
geo = { version = "0.29", optional = true }
@@ -62,5 +65,7 @@ wkt = { version = "0.11", optional = true }

[dev-dependencies]
approx = "0.5"
futures.workspace = true
pretty_assertions = "1.4.0"
serde = { version = "1.0", features = ["derive"] }
tokio.workspace = true
src/common/function/src/aggrs.rs
@@ -17,3 +17,5 @@ pub mod count_hash;
#[cfg(feature = "geo")]
pub mod geo;
pub mod vector;

pub mod aggr_wrapper;
src/common/function/src/aggrs/aggr_wrapper.rs (new file, 538 lines)
@@ -0,0 +1,538 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Wrapper for making aggregate functions out of the state/merge functions of original aggregate functions.
//!
//! i.e., for an aggregate function `foo`, we will have a state function `foo_state` and a merge function `foo_merge`.
//!
//! `foo_state`'s input args are the same as `foo`'s, and its output is a state object.
//! Note that `foo_state` might have multiple output columns, so its output is a struct array
//! in which each output column is a struct field.
//! `foo_merge`'s input arg is the same as `foo_state`'s output, and its output is the same as `foo`'s output.

use std::sync::Arc;

use arrow::array::StructArray;
use arrow_schema::Fields;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
    Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
    Signature,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};

/// Returns the name of the state function for the given aggregate function name.
/// The state function is used to compute the state of the aggregate function.
/// The state function's name is in the format `__<aggr_name>_state`.
pub fn aggr_state_func_name(aggr_name: &str) -> String {
    format!("__{}_state", aggr_name)
}

/// Returns the name of the merge function for the given aggregate function name.
/// The merge function is used to merge the states of the state functions.
/// The merge function's name is in the format `__<aggr_name>_merge`.
pub fn aggr_merge_func_name(aggr_name: &str) -> String {
    format!("__{}_merge", aggr_name)
}

/// A helper to make aggregate functions out of the state and merge functions of the original
/// aggregate function.
///
/// Notice that a state function may have multiple output columns, so its return type is always
/// a struct array, and the merge function is used to merge the states produced by the state
/// function.
#[derive(Debug, Clone)]
pub struct StateMergeHelper;

/// A struct to hold the two aggregate plans: one for the state function (lower) and one for the merge function (upper).
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct StepAggrPlan {
    /// Upper merge plan: the aggregate plan that merges the states produced by the state function.
    pub upper_merge: Arc<LogicalPlan>,
    /// Lower state plan: the aggregate plan that computes the state of the aggregate function.
    pub lower_state: Arc<LogicalPlan>,
}

pub fn get_aggr_func(expr: &Expr) -> Option<&datafusion_expr::expr::AggregateFunction> {
    let mut expr_ref = expr;
    while let Expr::Alias(alias) = expr_ref {
        expr_ref = &alias.expr;
    }
    if let Expr::AggregateFunction(aggr_func) = expr_ref {
        Some(aggr_func)
    } else {
        None
    }
}

impl StateMergeHelper {
    /// Split an aggregate plan into two aggregate plans, one for the state function and one for the merge function.
    pub fn split_aggr_node(aggr_plan: Aggregate) -> datafusion_common::Result<StepAggrPlan> {
        let aggr = {
            // Certain aggregate functions need type coercion to work correctly, so analyze the plan first.
            let aggr_plan = TypeCoercion::new().analyze(
                LogicalPlan::Aggregate(aggr_plan).clone(),
                &Default::default(),
            )?;
            if let LogicalPlan::Aggregate(aggr) = aggr_plan {
                aggr
            } else {
                return Err(datafusion_common::DataFusionError::Internal(format!(
                    "Failed to coerce expressions in aggregate plan, expected Aggregate, got: {:?}",
                    aggr_plan
                )));
            }
        };
        let mut lower_aggr_exprs = vec![];
        let mut upper_aggr_exprs = vec![];

        for aggr_expr in aggr.aggr_expr.iter() {
            let Some(aggr_func) = get_aggr_func(aggr_expr) else {
                return Err(datafusion_common::DataFusionError::NotImplemented(format!(
                    "Unsupported aggregate expression for step aggr optimize: {:?}",
                    aggr_expr
                )));
            };

            let original_input_types = aggr_func
                .args
                .iter()
                .map(|e| e.get_type(&aggr.input.schema()))
                .collect::<Result<Vec<_>, _>>()?;

            // First create the state function from the original aggregate function.
            let state_func = StateWrapper::new((*aggr_func.func).clone())?;

            let expr = AggregateFunction {
                func: Arc::new(state_func.into()),
                args: aggr_func.args.clone(),
                distinct: aggr_func.distinct,
                filter: aggr_func.filter.clone(),
                order_by: aggr_func.order_by.clone(),
                null_treatment: aggr_func.null_treatment,
            };
            let expr = Expr::AggregateFunction(expr);
            let lower_state_output_col_name = expr.schema_name().to_string();

            lower_aggr_exprs.push(expr);

            let (original_phy_expr, _filter, _ordering) = create_aggregate_expr_and_maybe_filter(
                aggr_expr,
                aggr.input.schema(),
                aggr.input.schema().as_arrow(),
                &Default::default(),
            )?;

            let merge_func = MergeWrapper::new(
                (*aggr_func.func).clone(),
                original_phy_expr,
                original_input_types,
            )?;
            let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
            let expr = AggregateFunction {
                func: Arc::new(merge_func.into()),
                args: vec![arg],
                distinct: aggr_func.distinct,
                filter: aggr_func.filter.clone(),
                order_by: aggr_func.order_by.clone(),
                null_treatment: aggr_func.null_treatment,
            };

            // Alias to the original aggregate expr's schema name, so the parent plan can refer
            // to it correctly.
            let expr = Expr::AggregateFunction(expr).alias(aggr_expr.schema_name().to_string());
            upper_aggr_exprs.push(expr);
        }

        let mut lower = aggr.clone();
        lower.aggr_expr = lower_aggr_exprs;
        let lower_plan = LogicalPlan::Aggregate(lower);

        // Update the aggregate's output schema.
        let lower_plan = Arc::new(lower_plan.recompute_schema()?);

        let mut upper = aggr.clone();
        let aggr_plan = LogicalPlan::Aggregate(aggr);
        upper.aggr_expr = upper_aggr_exprs;
        upper.input = lower_plan.clone();
        // The upper plan's output schema should be the same as the original aggregate plan's output schema.
        let upper_check = upper.clone();
        let upper_plan = Arc::new(LogicalPlan::Aggregate(upper_check).recompute_schema()?);
        if *upper_plan.schema() != *aggr_plan.schema() {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Upper aggregate plan's schema is not the same as the original aggregate plan's schema: \n[transformed]:{}\n[ original]{}",
                upper_plan.schema(), aggr_plan.schema()
            )));
        }

        Ok(StepAggrPlan {
            lower_state: lower_plan,
            upper_merge: upper_plan,
        })
    }
}

/// Wrapper to make an aggregate function out of a state function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateWrapper {
    inner: AggregateUDF,
    name: String,
}

impl StateWrapper {
    pub fn new(inner: AggregateUDF) -> datafusion_common::Result<Self> {
        let name = aggr_state_func_name(inner.name());
        Ok(Self { inner, name })
    }

    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }

    /// Deduce the return type of the original aggregate function
    /// based on the accumulator arguments.
    pub fn deduce_aggr_return_type(
        &self,
        acc_args: &datafusion_expr::function::AccumulatorArgs,
    ) -> datafusion_common::Result<DataType> {
        let input_exprs = acc_args.exprs;
        let input_schema = acc_args.schema;
        let input_types = input_exprs
            .iter()
            .map(|e| e.data_type(input_schema))
            .collect::<Result<Vec<_>, _>>()?;
        let return_type = self.inner.return_type(&input_types)?;
        Ok(return_type)
    }
}

impl AggregateUDFImpl for StateWrapper {
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        // Fix up and recover the proper acc args for the original aggregate function.
        let state_type = acc_args.return_type.clone();
        let inner = {
            let old_return_type = self.deduce_aggr_return_type(&acc_args)?;
            let acc_args = datafusion_expr::function::AccumulatorArgs {
                return_type: &old_return_type,
                schema: acc_args.schema,
                ignore_nulls: acc_args.ignore_nulls,
                ordering_req: acc_args.ordering_req,
                is_reversed: acc_args.is_reversed,
                name: acc_args.name,
                is_distinct: acc_args.is_distinct,
                exprs: acc_args.exprs,
            };
            self.inner.accumulator(acc_args)?
        };
        Ok(Box::new(StateAccum::new(inner, state_type)?))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Return the state fields as the output struct type.
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let old_return_type = self.inner.return_type(arg_types)?;
        let state_fields_args = StateFieldsArgs {
            name: self.inner().name(),
            input_types: arg_types,
            return_type: &old_return_type,
            // TODO(discord9): how to get this? probably ok?
            ordering_fields: &[],
            is_distinct: false,
        };
        let state_fields = self.inner.state_fields(state_fields_args)?;
        let struct_field = DataType::Struct(state_fields.into());
        Ok(struct_field)
    }

    /// The state function's output fields are the same as the original aggregate function's state fields.
    fn state_fields(
        &self,
        args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<Field>> {
        let old_return_type = self.inner.return_type(args.input_types)?;
        let state_fields_args = StateFieldsArgs {
            name: args.name,
            input_types: args.input_types,
            return_type: &old_return_type,
            ordering_fields: args.ordering_fields,
            is_distinct: args.is_distinct,
        };
        self.inner.state_fields(state_fields_args)
    }

    /// The state function's signature is the same as the original aggregate function's signature.
    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    /// Coerce types using the original aggregate function's rules; the optimizer should
    /// already have produced the expected types.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }
}

/// The state accumulator: its input is the same as the original aggregate function's input,
/// and its output is the state function's output.
#[derive(Debug)]
pub struct StateAccum {
    inner: Box<dyn Accumulator>,
    state_fields: Fields,
}

impl StateAccum {
    pub fn new(
        inner: Box<dyn Accumulator>,
        state_type: DataType,
    ) -> datafusion_common::Result<Self> {
        let DataType::Struct(fields) = state_type else {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for state, got: {:?}",
                state_type
            )));
        };
        Ok(Self {
            inner,
            state_fields: fields,
        })
    }
}

impl Accumulator for StateAccum {
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
        let state = self.inner.state()?;

        let array = state
            .iter()
            .map(|s| s.to_array())
            .collect::<Result<Vec<_>, _>>()?;
        let struct_array = StructArray::try_new(self.state_fields.clone(), array, None)?;
        Ok(ScalarValue::Struct(Arc::new(struct_array)))
    }

    fn merge_batch(
        &mut self,
        states: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.merge_batch(states)
    }

    fn update_batch(
        &mut self,
        values: &[datatypes::arrow::array::ArrayRef],
    ) -> datafusion_common::Result<()> {
        self.inner.update_batch(values)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }

    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
        self.inner.state()
    }
}

/// TODO(discord9): mark this function as non-ser/de-able.
///
/// This wrapper shouldn't be registered as a UDAF, as it contains extra data that is not
/// serializable and changes for different logical plans.
#[derive(Debug, Clone)]
pub struct MergeWrapper {
    inner: AggregateUDF,
    name: String,
    merge_signature: Signature,
    /// The original physical expression of the aggregate function; we can't store the original
    /// aggregate function directly, as `PhysicalExpr` doesn't implement `Any`.
    original_phy_expr: Arc<AggregateFunctionExpr>,
    original_input_types: Vec<DataType>,
}

impl MergeWrapper {
    pub fn new(
        inner: AggregateUDF,
        original_phy_expr: Arc<AggregateFunctionExpr>,
        original_input_types: Vec<DataType>,
    ) -> datafusion_common::Result<Self> {
        let name = aggr_merge_func_name(inner.name());
        // The input type is actually a struct type, made of the state fields of the original
        // aggregate function.
        let merge_signature = Signature::user_defined(datafusion_expr::Volatility::Immutable);

        Ok(Self {
            inner,
            name,
            merge_signature,
            original_phy_expr,
            original_input_types,
        })
    }

    pub fn inner(&self) -> &AggregateUDF {
        &self.inner
    }
}

impl AggregateUDFImpl for MergeWrapper {
    fn accumulator<'a, 'b>(
        &'a self,
        acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
        if acc_args.schema.fields().len() != 1
            || !matches!(acc_args.schema.field(0).data_type(), DataType::Struct(_))
        {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                acc_args.schema
            )));
        }
        let input_type = acc_args.schema.field(0).data_type();
        let DataType::Struct(fields) = input_type else {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for input, got: {:?}",
                input_type
            )));
        };

        let inner_accum = self.original_phy_expr.create_accumulator()?;
        Ok(Box::new(MergeAccum::new(inner_accum, fields)))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn is_nullable(&self) -> bool {
        self.inner.is_nullable()
    }

    /// Notice that here `arg_types` is actually the `state_fields`' data types,
    /// so return a fixed return type instead of using `arg_types` to determine it.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        // The return type is the same as the original aggregate function's return type.
        let ret_type = self.inner.return_type(&self.original_input_types)?;
        Ok(ret_type)
    }

    fn signature(&self) -> &Signature {
        &self.merge_signature
    }

    /// Coercion only validates the types here, as the optimizer should already have produced
    /// the struct types.
    fn coerce_types(&self, arg_types: &[DataType]) -> datafusion_common::Result<Vec<DataType>> {
        // Just check that the arg_types contain exactly one struct array type.
        if arg_types.len() != 1 || !matches!(arg_types.first(), Some(DataType::Struct(_))) {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected one struct type as input, got: {:?}",
                arg_types
            )));
        }
        Ok(arg_types.to_vec())
    }

    /// Just return the original aggregate function's state fields.
    fn state_fields(
        &self,
        _args: datafusion_expr::function::StateFieldsArgs,
    ) -> datafusion_common::Result<Vec<Field>> {
        self.original_phy_expr.state_fields()
    }
}

/// The merge accumulator, which modifies `update_batch`'s behavior to accept one struct array
/// that includes the state fields of the original aggregate function, and merges said states
/// into the original accumulator. The output is the same as the original aggregate function's.
#[derive(Debug)]
pub struct MergeAccum {
    inner: Box<dyn Accumulator>,
    state_fields: Fields,
}

impl MergeAccum {
    pub fn new(inner: Box<dyn Accumulator>, state_fields: &Fields) -> Self {
        Self {
            inner,
            state_fields: state_fields.clone(),
        }
    }
}

impl Accumulator for MergeAccum {
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
        self.inner.evaluate()
    }

    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
        self.inner.merge_batch(states)
    }

    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion_common::Result<()> {
        let value = values.first().ok_or_else(|| {
            datafusion_common::DataFusionError::Internal("No values provided for merge".to_string())
        })?;
        // The input values are states from other accumulators, so we merge them.
        let struct_arr = value
            .as_any()
            .downcast_ref::<StructArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(format!(
                    "Expected StructArray, got: {:?}",
                    value.data_type()
                ))
            })?;
        let fields = struct_arr.fields();
        if fields != &self.state_fields {
            return Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected state fields: {:?}, got: {:?}",
                self.state_fields, fields
            )));
        }

        // The fields are now known to match, so we can merge the batch by passing the
        // columns straight through, as the order should be the same.
        let state_columns = struct_arr.columns();
        self.inner.merge_batch(state_columns)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }

    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
        self.inner.state()
    }
}

#[cfg(test)]
mod tests;
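A quick illustration of the naming contract established above. The `use` path assumes this module is mounted in the `common-function` crate as `common_function::aggrs::aggr_wrapper` (the file paths in this diff suggest so, but the crate name is an assumption, not confirmed by the diff):

// Hypothetical check of the naming helpers; import path assumed as noted above.
use common_function::aggrs::aggr_wrapper::{aggr_merge_func_name, aggr_state_func_name};

fn main() {
    // The lower plan's output column is named after the state expression,
    // e.g. `__sum_state(number)`; the upper plan reads that column back and
    // aliases its own result to the original name, e.g. `sum(number)`.
    assert_eq!(aggr_state_func_name("sum"), "__sum_state");
    assert_eq!(aggr_merge_func_name("sum"), "__sum_merge");
}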
src/common/function/src/aggrs/aggr_wrapper/tests.rs (new file, 804 lines)
@@ -0,0 +1,804 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
use arrow::record_batch::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::catalog::{Session, TableProvider};
use datafusion::datasource::DefaultTableSource;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
use datafusion::prelude::SessionContext;
use datafusion_common::{Column, TableReference};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::sqlparser::ast::NullTreatment;
use datafusion_expr::{Aggregate, Expr, LogicalPlan, SortExpr, TableScan};
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datatypes::arrow_array::StringArray;
use futures::{Stream, StreamExt as _};
use pretty_assertions::assert_eq;

use super::*;
use crate::aggrs::approximate::hll::HllState;
use crate::aggrs::approximate::uddsketch::UddSketchState;
use crate::aggrs::count_hash::CountHash;
use crate::function::Function as _;
use crate::scalars::hll_count::HllCalcFunction;
use crate::scalars::uddsketch_calc::UddSketchCalcFunction;

#[derive(Debug)]
pub struct MockInputExec {
    input: Vec<RecordBatch>,
    schema: SchemaRef,
    properties: PlanProperties,
}

impl MockInputExec {
    pub fn new(input: Vec<RecordBatch>, schema: SchemaRef) -> Self {
        Self {
            properties: PlanProperties::new(
                EquivalenceProperties::new(schema.clone()),
                Partitioning::UnknownPartitioning(1),
                EmissionType::Incremental,
                Boundedness::Bounded,
            ),
            input,
            schema,
        }
    }
}

impl DisplayAs for MockInputExec {
    fn fmt_as(&self, _t: DisplayFormatType, _f: &mut std::fmt::Formatter) -> std::fmt::Result {
        unimplemented!()
    }
}

impl ExecutionPlan for MockInputExec {
    fn name(&self) -> &str {
        "MockInputExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> datafusion_common::Result<SendableRecordBatchStream> {
        let stream = MockStream {
            stream: self.input.clone(),
            schema: self.schema.clone(),
            idx: 0,
        };
        Ok(Box::pin(stream))
    }
}

struct MockStream {
    stream: Vec<RecordBatch>,
    schema: SchemaRef,
    idx: usize,
}

impl Stream for MockStream {
    type Item = datafusion_common::Result<RecordBatch>;
    fn poll_next(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<datafusion_common::Result<RecordBatch>>> {
        if self.idx < self.stream.len() {
            let ret = self.stream[self.idx].clone();
            self.idx += 1;
            Poll::Ready(Some(Ok(ret)))
        } else {
            Poll::Ready(None)
        }
    }
}

impl RecordBatchStream for MockStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

#[derive(Debug)]
struct DummyTableProvider {
    schema: Arc<arrow_schema::Schema>,
    record_batch: Mutex<Option<RecordBatch>>,
}

impl DummyTableProvider {
    #[allow(unused)]
    pub fn new(schema: Arc<arrow_schema::Schema>, record_batch: Option<RecordBatch>) -> Self {
        Self {
            schema,
            record_batch: Mutex::new(record_batch),
        }
    }
}

impl Default for DummyTableProvider {
    fn default() -> Self {
        Self {
            schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
                "number",
                DataType::Int64,
                true,
            )])),
            record_batch: Mutex::new(None),
        }
    }
}

#[async_trait::async_trait]
impl TableProvider for DummyTableProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> Arc<arrow_schema::Schema> {
        self.schema.clone()
    }

    fn table_type(&self) -> datafusion_expr::TableType {
        datafusion_expr::TableType::Base
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        _projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        let input: Vec<RecordBatch> = self
            .record_batch
            .lock()
            .unwrap()
            .clone()
            .map(|r| vec![r])
            .unwrap_or_default();
        Ok(Arc::new(MockInputExec::new(input, self.schema.clone())))
    }
}

fn dummy_table_scan() -> LogicalPlan {
    let table_provider = Arc::new(DummyTableProvider::default());
    let table_source = DefaultTableSource::new(table_provider);
    LogicalPlan::TableScan(
        TableScan::try_new(
            TableReference::bare("Number"),
            Arc::new(table_source),
            None,
            vec![],
            None,
        )
        .unwrap(),
    )
}

#[tokio::test]
async fn test_sum_udaf() {
    let ctx = SessionContext::new();

    let sum = datafusion::functions_aggregate::sum::sum_udaf();
    let sum = (*sum).clone();
    let original_aggr = Aggregate::try_new(
        Arc::new(dummy_table_scan()),
        vec![],
        vec![Expr::AggregateFunction(AggregateFunction::new_udf(
            Arc::new(sum.clone()),
            vec![Expr::Column(Column::new_unqualified("number"))],
            false,
            None,
            None,
            None,
        ))],
    )
    .unwrap();
    let res = StateMergeHelper::split_aggr_node(original_aggr).unwrap();

    let expected_lower_plan = LogicalPlan::Aggregate(
        Aggregate::try_new(
            Arc::new(dummy_table_scan()),
            vec![],
            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
                Arc::new(StateWrapper::new(sum.clone()).unwrap().into()),
                vec![Expr::Column(Column::new_unqualified("number"))],
                false,
                None,
                None,
                None,
            ))],
        )
        .unwrap(),
    )
    .recompute_schema()
    .unwrap();
    assert_eq!(res.lower_state.as_ref(), &expected_lower_plan);

    let expected_merge_plan = LogicalPlan::Aggregate(
        Aggregate::try_new(
            Arc::new(expected_lower_plan),
            vec![],
            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
                Arc::new(
                    MergeWrapper::new(
                        sum.clone(),
                        Arc::new(
                            AggregateExprBuilder::new(
                                Arc::new(sum.clone()),
                                vec![Arc::new(
                                    datafusion::physical_expr::expressions::Column::new(
                                        "number", 0,
                                    ),
                                )],
                            )
                            .schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
                            .alias("sum(number)")
                            .build()
                            .unwrap(),
                        ),
                        vec![DataType::Int64],
                    )
                    .unwrap()
                    .into(),
                ),
                vec![Expr::Column(Column::new_unqualified("__sum_state(number)"))],
                false,
                None,
                None,
                None,
            ))
            .alias("sum(number)")],
        )
        .unwrap(),
    );
    assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);

    let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
        .create_physical_plan(&res.lower_state, &ctx.state())
        .await
        .unwrap();
    let aggr_exec = phy_aggr_state_plan
        .as_any()
        .downcast_ref::<AggregateExec>()
        .unwrap();
    let aggr_func_expr = &aggr_exec.aggr_expr()[0];
    let mut state_accum = aggr_func_expr.create_accumulator().unwrap();

    // Evaluate the state function.
    let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
    let values = vec![Arc::new(input) as arrow::array::ArrayRef];

    state_accum.update_batch(&values).unwrap();
    let state = state_accum.state().unwrap();
    assert_eq!(state.len(), 1);
    assert_eq!(state[0], ScalarValue::Int64(Some(6)));

    let eval_res = state_accum.evaluate().unwrap();
    assert_eq!(
        eval_res,
        ScalarValue::Struct(Arc::new(
            StructArray::try_new(
                vec![Field::new("sum[sum]", DataType::Int64, true)].into(),
                vec![Arc::new(Int64Array::from(vec![Some(6)]))],
                None,
            )
            .unwrap(),
        ))
    );

    let phy_aggr_merge_plan = DefaultPhysicalPlanner::default()
        .create_physical_plan(&res.upper_merge, &ctx.state())
        .await
        .unwrap();
    let aggr_exec = phy_aggr_merge_plan
        .as_any()
        .downcast_ref::<AggregateExec>()
        .unwrap();
    let aggr_func_expr = &aggr_exec.aggr_expr()[0];
    let mut merge_accum = aggr_func_expr.create_accumulator().unwrap();

    let merge_input =
        vec![Arc::new(Int64Array::from(vec![Some(6), Some(42), None])) as arrow::array::ArrayRef];
    let merge_input_struct_arr = StructArray::try_new(
        vec![Field::new("sum[sum]", DataType::Int64, true)].into(),
        merge_input,
        None,
    )
    .unwrap();

    merge_accum
        .update_batch(&[Arc::new(merge_input_struct_arr)])
        .unwrap();
    let merge_state = merge_accum.state().unwrap();
    assert_eq!(merge_state.len(), 1);
    assert_eq!(merge_state[0], ScalarValue::Int64(Some(48)));

    let merge_eval_res = merge_accum.evaluate().unwrap();
    assert_eq!(merge_eval_res, ScalarValue::Int64(Some(48)));
}

#[tokio::test]
async fn test_avg_udaf() {
    let ctx = SessionContext::new();

    let avg = datafusion::functions_aggregate::average::avg_udaf();
    let avg = (*avg).clone();

    let original_aggr = Aggregate::try_new(
        Arc::new(dummy_table_scan()),
        vec![],
        vec![Expr::AggregateFunction(AggregateFunction::new_udf(
            Arc::new(avg.clone()),
            vec![Expr::Column(Column::new_unqualified("number"))],
            false,
            None,
            None,
            None,
        ))],
    )
    .unwrap();
    let res = StateMergeHelper::split_aggr_node(original_aggr).unwrap();

    let state_func: Arc<AggregateUDF> = Arc::new(StateWrapper::new(avg.clone()).unwrap().into());
    let expected_aggr_state_plan = LogicalPlan::Aggregate(
        Aggregate::try_new(
            Arc::new(dummy_table_scan()),
            vec![],
            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
                state_func,
                vec![Expr::Column(Column::new_unqualified("number"))],
                false,
                None,
                None,
                None,
            ))],
        )
        .unwrap(),
    );
    // Type-coerce so the avg aggregate function can work correctly.
    let coerced_aggr_state_plan = TypeCoercion::new()
        .analyze(expected_aggr_state_plan.clone(), &Default::default())
        .unwrap();
    assert_eq!(res.lower_state.as_ref(), &coerced_aggr_state_plan);
    assert_eq!(
        res.lower_state.schema().as_arrow(),
        &arrow_schema::Schema::new(vec![Field::new(
            "__avg_state(number)",
            DataType::Struct(
                vec![
                    Field::new("avg[count]", DataType::UInt64, true),
                    Field::new("avg[sum]", DataType::Float64, true)
                ]
                .into()
            ),
            true,
        )])
    );

    let expected_merge_fn = MergeWrapper::new(
        avg.clone(),
        Arc::new(
            AggregateExprBuilder::new(
                Arc::new(avg.clone()),
                vec![Arc::new(
                    datafusion::physical_expr::expressions::Column::new("number", 0),
                )],
            )
            .schema(Arc::new(dummy_table_scan().schema().as_arrow().clone()))
            .alias("avg(number)")
            .build()
            .unwrap(),
        ),
        // coerced to float64
        vec![DataType::Float64],
    )
    .unwrap();

    let expected_merge_plan = LogicalPlan::Aggregate(
        Aggregate::try_new(
            Arc::new(coerced_aggr_state_plan.clone()),
            vec![],
            vec![Expr::AggregateFunction(AggregateFunction::new_udf(
                Arc::new(expected_merge_fn.into()),
                vec![Expr::Column(Column::new_unqualified("__avg_state(number)"))],
                false,
                None,
                None,
                None,
            ))
            .alias("avg(number)")],
        )
        .unwrap(),
    );
    assert_eq!(res.upper_merge.as_ref(), &expected_merge_plan);

    let phy_aggr_state_plan = DefaultPhysicalPlanner::default()
        .create_physical_plan(&coerced_aggr_state_plan, &ctx.state())
        .await
        .unwrap();
    let aggr_exec = phy_aggr_state_plan
        .as_any()
        .downcast_ref::<AggregateExec>()
        .unwrap();
    let aggr_func_expr = &aggr_exec.aggr_expr()[0];
    let mut state_accum = aggr_func_expr.create_accumulator().unwrap();

    // Evaluate the state function.
    let input = Float64Array::from(vec![Some(1.), Some(2.), None, Some(3.)]);
    let values = vec![Arc::new(input) as arrow::array::ArrayRef];

    state_accum.update_batch(&values).unwrap();
    let state = state_accum.state().unwrap();
    assert_eq!(state.len(), 2);
    assert_eq!(state[0], ScalarValue::UInt64(Some(3)));
    assert_eq!(state[1], ScalarValue::Float64(Some(6.)));

    let eval_res = state_accum.evaluate().unwrap();
    let expected = Arc::new(
        StructArray::try_new(
            vec![
                Field::new("avg[count]", DataType::UInt64, true),
                Field::new("avg[sum]", DataType::Float64, true),
            ]
            .into(),
            vec![
                Arc::new(UInt64Array::from(vec![Some(3)])),
                Arc::new(Float64Array::from(vec![Some(6.)])),
            ],
            None,
        )
        .unwrap(),
    );
    assert_eq!(eval_res, ScalarValue::Struct(expected));

    let phy_aggr_merge_plan = DefaultPhysicalPlanner::default()
        .create_physical_plan(&res.upper_merge, &ctx.state())
        .await
        .unwrap();
    let aggr_exec = phy_aggr_merge_plan
        .as_any()
        .downcast_ref::<AggregateExec>()
        .unwrap();
    let aggr_func_expr = &aggr_exec.aggr_expr()[0];

    let mut merge_accum = aggr_func_expr.create_accumulator().unwrap();

    let merge_input = vec![
        Arc::new(UInt64Array::from(vec![Some(3), Some(42), None])) as arrow::array::ArrayRef,
        Arc::new(Float64Array::from(vec![Some(48.), Some(84.), None])),
    ];
    let merge_input_struct_arr = StructArray::try_new(
        vec![
            Field::new("avg[count]", DataType::UInt64, true),
            Field::new("avg[sum]", DataType::Float64, true),
        ]
        .into(),
        merge_input,
        None,
    )
    .unwrap();

    merge_accum
        .update_batch(&[Arc::new(merge_input_struct_arr)])
        .unwrap();
    let merge_state = merge_accum.state().unwrap();
    assert_eq!(merge_state.len(), 2);
    assert_eq!(merge_state[0], ScalarValue::UInt64(Some(45)));
    assert_eq!(merge_state[1], ScalarValue::Float64(Some(132.)));

    let merge_eval_res = merge_accum.evaluate().unwrap();
    // The merge function returns the average, which is 132 / 45.
    assert_eq!(merge_eval_res, ScalarValue::Float64(Some(132. / 45_f64)));
}

/// For testing whether the UDAF state fields are correctly implemented,
/// especially for our own custom UDAFs' state fields,
/// by comparing eval results before and after the split into state/merge functions.
#[tokio::test]
async fn test_udaf_correct_eval_result() {
    struct TestCase {
        func: Arc<AggregateUDF>,
        args: Vec<Expr>,
        input_schema: SchemaRef,
        input: Vec<ArrayRef>,
        expected_output: Option<ScalarValue>,
        expected_fn: Option<ExpectedFn>,
        distinct: bool,
        filter: Option<Box<Expr>>,
        order_by: Option<Vec<SortExpr>>,
        null_treatment: Option<NullTreatment>,
    }
    type ExpectedFn = fn(ArrayRef) -> bool;

    let test_cases = vec![
        TestCase {
            func: sum_udaf(),
            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
                "number",
                DataType::Int64,
                true,
            )])),
            args: vec![Expr::Column(Column::new_unqualified("number"))],
            input: vec![Arc::new(Int64Array::from(vec![
                Some(1),
                Some(2),
                None,
                Some(3),
            ]))],
            expected_output: Some(ScalarValue::Int64(Some(6))),
            expected_fn: None,
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        },
        TestCase {
            func: avg_udaf(),
            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
                "number",
                DataType::Int64,
                true,
            )])),
            args: vec![Expr::Column(Column::new_unqualified("number"))],
            input: vec![Arc::new(Int64Array::from(vec![
                Some(1),
                Some(2),
                None,
                Some(3),
            ]))],
            expected_output: Some(ScalarValue::Float64(Some(2.0))),
            expected_fn: None,
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        },
        TestCase {
            func: Arc::new(CountHash::udf_impl()),
            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
                "number",
                DataType::Int64,
                true,
            )])),
            args: vec![Expr::Column(Column::new_unqualified("number"))],
            input: vec![Arc::new(Int64Array::from(vec![
                Some(1),
                Some(2),
                None,
                Some(3),
                Some(3),
                Some(3),
            ]))],
            expected_output: Some(ScalarValue::Int64(Some(4))),
            expected_fn: None,
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        },
        TestCase {
            func: Arc::new(UddSketchState::state_udf_impl()),
            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
                "number",
                DataType::Float64,
                true,
            )])),
            args: vec![
                Expr::Literal(ScalarValue::Int64(Some(128))),
                Expr::Literal(ScalarValue::Float64(Some(0.05))),
                Expr::Column(Column::new_unqualified("number")),
            ],
            input: vec![Arc::new(Float64Array::from(vec![
                Some(1.),
                Some(2.),
                None,
                Some(3.),
                Some(3.),
                Some(3.),
            ]))],
            expected_output: None,
            expected_fn: Some(|arr| {
                let percent = ScalarValue::Float64(Some(0.5)).to_array().unwrap();
                let percent = datatypes::vectors::Helper::try_into_vector(percent).unwrap();
                let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
                let udd_calc = UddSketchCalcFunction;
                let res = udd_calc
                    .eval(&Default::default(), &[percent, state])
                    .unwrap();
                let binding = res.to_arrow_array();
                let res_arr = binding.as_any().downcast_ref::<Float64Array>().unwrap();
                assert!(res_arr.len() == 1);
                assert!((res_arr.value(0) - 2.856578984907706f64).abs() <= f64::EPSILON);
                true
            }),
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        },
        TestCase {
            func: Arc::new(HllState::state_udf_impl()),
            input_schema: Arc::new(arrow_schema::Schema::new(vec![Field::new(
                "word",
                DataType::Utf8,
                true,
            )])),
            args: vec![Expr::Column(Column::new_unqualified("word"))],
            input: vec![Arc::new(StringArray::from(vec![
                Some("foo"),
                Some("bar"),
                None,
                Some("baz"),
                Some("baz"),
            ]))],
            expected_output: None,
            expected_fn: Some(|arr| {
                let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap();
                let hll_calc = HllCalcFunction;
                let res = hll_calc.eval(&Default::default(), &[state]).unwrap();
                let binding = res.to_arrow_array();
                let res_arr = binding.as_any().downcast_ref::<UInt64Array>().unwrap();
                assert!(res_arr.len() == 1);
                assert_eq!(res_arr.value(0), 3);
                true
            }),
            distinct: false,
            filter: None,
            order_by: None,
            null_treatment: None,
        },
        // TODO(discord9): udd_merge/hll_merge/geo_path/quantile_aggr tests
    ];
    let test_table_ref = TableReference::bare("TestTable");

    for case in test_cases {
        let ctx = SessionContext::new();
        let table_provider = DummyTableProvider::new(
            case.input_schema.clone(),
            Some(RecordBatch::try_new(case.input_schema.clone(), case.input.clone()).unwrap()),
        );
        let table_source = DefaultTableSource::new(Arc::new(table_provider));
        let logical_plan = LogicalPlan::TableScan(
            TableScan::try_new(
                test_table_ref.clone(),
                Arc::new(table_source),
                None,
                vec![],
                None,
            )
            .unwrap(),
        );

        let args = case.args;

        let aggr_expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            case.func.clone(),
            args,
            case.distinct,
            case.filter,
            case.order_by,
            case.null_treatment,
        ));

        let aggr_plan = LogicalPlan::Aggregate(
            Aggregate::try_new(Arc::new(logical_plan), vec![], vec![aggr_expr]).unwrap(),
        );

        // Make sure the aggr_plan is type-coerced.
        let aggr_plan = TypeCoercion::new()
            .analyze(aggr_plan, &Default::default())
            .unwrap();

        // First evaluate the original aggregate function.
        let phy_full_aggr_plan = DefaultPhysicalPlanner::default()
            .create_physical_plan(&aggr_plan, &ctx.state())
            .await
            .unwrap();

        {
            let unsplit_result = execute_phy_plan(&phy_full_aggr_plan).await.unwrap();
            assert_eq!(unsplit_result.len(), 1);
            let unsplit_batch = &unsplit_result[0];
            assert_eq!(unsplit_batch.num_columns(), 1);
            assert_eq!(unsplit_batch.num_rows(), 1);
            let unsplit_col = unsplit_batch.column(0);
            if let Some(expected_output) = &case.expected_output {
                assert_eq!(unsplit_col.data_type(), &expected_output.data_type());
                assert_eq!(unsplit_col.len(), 1);
                assert_eq!(unsplit_col, &expected_output.to_array().unwrap());
            }

            if let Some(expected_fn) = &case.expected_fn {
                assert!(expected_fn(unsplit_col.clone()));
            }
        }
        let LogicalPlan::Aggregate(aggr_plan) = aggr_plan else {
            panic!("Expected Aggregate plan");
        };
        let split_plan = StateMergeHelper::split_aggr_node(aggr_plan).unwrap();

        let phy_upper_plan = DefaultPhysicalPlanner::default()
            .create_physical_plan(&split_plan.upper_merge, &ctx.state())
            .await
            .unwrap();

        // Since the upper plan uses the lower plan as input, executing the upper plan also
        // executes the lower plan, which should give the same result as the original
        // aggregate function.
        {
            let split_res = execute_phy_plan(&phy_upper_plan).await.unwrap();

            assert_eq!(split_res.len(), 1);
            let split_batch = &split_res[0];
            assert_eq!(split_batch.num_columns(), 1);
            assert_eq!(split_batch.num_rows(), 1);
            let split_col = split_batch.column(0);
            if let Some(expected_output) = &case.expected_output {
                assert_eq!(split_col.data_type(), &expected_output.data_type());
                assert_eq!(split_col.len(), 1);
                assert_eq!(split_col, &expected_output.to_array().unwrap());
            }

            if let Some(expected_fn) = &case.expected_fn {
                assert!(expected_fn(split_col.clone()));
            }
        }
    }
}

async fn execute_phy_plan(
    phy_plan: &Arc<dyn ExecutionPlan>,
) -> datafusion_common::Result<Vec<RecordBatch>> {
    let task_ctx = Arc::new(TaskContext::default());
    let mut stream = phy_plan.execute(0, task_ctx)?;
    let mut batches = Vec::new();
    while let Some(batch) = stream.next().await {
        batches.push(batch?);
    }
    Ok(batches)
}
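To run the new tests locally, something like the following should work, assuming the crate rooted at src/common/function is the `common-function` package (the package name is inferred from the paths in this diff, not stated in it):

cargo test -p common-function aggr_wrapper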