feat: procedure macro to help writing UDAF (#242)

* feat: procedure macro to help writing UDAF

* resolve code review comments

Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
LFC
2022-09-13 10:39:44 +08:00
committed by GitHub
parent 628cdb89e8
commit 64b6b2afe1
22 changed files with 226 additions and 317 deletions

15
Cargo.lock generated
View File

@@ -816,6 +816,7 @@ dependencies = [
"arrow2",
"chrono-tz",
"common-error",
"common-function-macro",
"common-query",
"datafusion-common",
"datatypes",
@@ -836,6 +837,19 @@ dependencies = [
"statrs",
]
[[package]]
name = "common-function-macro"
version = "0.1.0"
dependencies = [
"arc-swap",
"common-query",
"datatypes",
"quote",
"snafu",
"static_assertions",
"syn",
]
[[package]]
name = "common-grpc"
version = "0.1.0"
@@ -3670,6 +3684,7 @@ dependencies = [
"catalog",
"common-error",
"common-function",
"common-function-macro",
"common-query",
"common-recordbatch",
"common-telemetry",

View File

@@ -6,6 +6,7 @@ members = [
"src/common/base",
"src/common/error",
"src/common/function",
"src/common/function-macro",
"src/common/grpc",
"src/common/query",
"src/common/recordbatch",

View File

@@ -8,27 +8,29 @@ So is there a way we can make an aggregate function that automatically match the
# 1. Impl `AggregateFunctionCreator` trait for your accumulator creator.
You must first define a struct that can store the input data's type. For example,
You must first define a struct that will be used to create your accumulator. For example,
```Rust
struct MySumAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, AggrFuncTypeStore)]
struct MySumAccumulatorCreator {}
```
Attribute macro `#[as_aggr_func_creator]` and derive macro `#[derive(Debug, AggrFuncTypeStore)]` must both be annotated on the struct. They work together to provide storage for the aggregate function's input data types, which are needed for creating a generic accumulator later.
> Note that the `as_aggr_func_creator` macro will add fields to the struct, so the struct can be defined neither as a unit struct without fields like `struct Foo;`, nor as a new type like `struct Foo(bar)`.
Then impl `AggregateFunctionCreator` trait on it. The definition of the trait is:
```Rust
pub trait AggregateFunctionCreator: Send + Sync + Debug {
fn creator(&self) -> AccumulatorCreatorFunction;
fn input_types(&self) -> Vec<ConcreteDataType>;
fn set_input_types(&self, input_types: Vec<ConcreteDataType>);
fn output_type(&self) -> ConcreteDataType;
fn state_types(&self) -> Vec<ConcreteDataType>;
}
```
our query engine will call `set_input_types` the very first, so you can use input data's type in methods that return output type and state types.
You can use input data's type in methods that return output type and state types (just invoke `input_types()`).
The output type is the aggregate function's output data type. For example, the `SUM` aggregate function's output type is `u64` for a `u32` datatype column. The state types are the types of the accumulator's internal states. Taking the `AVG` aggregate function on an `i32` column as an example, its state types are `i64` (for sum) and `u64` (for count).

View File

@@ -0,0 +1,18 @@
[package]
name = "common-function-macro"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[dependencies]
common-query = { path = "../query" }
datatypes = { path = "../../datatypes" }
quote = "1.0"
snafu = { version = "0.7", features = ["backtraces"] }
syn = "1.0"
[dev-dependencies]
arc-swap = "1.0"
static_assertions = "1.1.0"

View File

@@ -0,0 +1,71 @@
use proc_macro::TokenStream;
use quote::{quote, quote_spanned};
use syn::parse::Parser;
use syn::spanned::Spanned;
use syn::{parse_macro_input, DeriveInput, ItemStruct};
/// Make struct implemented trait [AggrFuncTypeStore], which is necessary when writing UDAF.
/// This derive macro is expect to be used along with attribute macro [as_aggr_func_creator].
#[proc_macro_derive(AggrFuncTypeStore)]
pub fn aggr_func_type_store_derive(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
impl_aggr_func_type_store(&ast)
}
/// Generates the `AggrFuncTypeStore` impl for the struct described by `ast`.
///
/// The generated methods read and write `self.input_types`, so the struct must carry that field.
/// `input_types()` fails with `InvalidInputState` when nothing has been stored yet;
/// `set_input_types()` stores the new types and fails with `InvalidInputState` when they differ
/// from the previously stored ones.
fn impl_aggr_func_type_store(ast: &DeriveInput) -> TokenStream {
    let name = &ast.ident;
    let expanded = quote! {
        // NOTE(review): these `use` items are emitted into the scope of the annotated struct and
        // code next to the struct appears to rely on the names they bring in, so they are kept.
        use common_query::logical_plan::accumulator::AggrFuncTypeStore;
        use common_query::error::{InvalidInputStateSnafu, Error as QueryError};
        use datatypes::prelude::ConcreteDataType;

        impl AggrFuncTypeStore for #name {
            fn input_types(&self) -> std::result::Result<Vec<ConcreteDataType>, QueryError> {
                let input_types = self.input_types.load();
                match input_types.as_ref() {
                    Some(types) => Ok((**types).clone()),
                    None => InvalidInputStateSnafu.fail(),
                }
            }

            fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> std::result::Result<(), QueryError> {
                // Share one allocation between the stored value and the comparison below
                // instead of deep-cloning the whole vector.
                let new = std::sync::Arc::new(input_types);
                let old = self.input_types.swap(Some(std::sync::Arc::clone(&new)));
                if let Some(old) = old {
                    // Vector equality covers both the length and the element-wise check.
                    snafu::ensure!(*old == *new, InvalidInputStateSnafu);
                }
                Ok(())
            }
        }
    };
    expanded.into()
}
/// A struct can be used as a creator for an aggregate function if it has been annotated with this
/// attribute first. This attribute adds a necessary field, `input_types`, which is intended to
/// store the input data's types, to the struct.
/// This attribute is expected to be used along with the derive macro [AggrFuncTypeStore].
///
/// Only structs with named fields (declared with braces) are supported; any other struct shape
/// is rejected with a compile error, because there is nowhere to add the named field.
#[proc_macro_attribute]
pub fn as_aggr_func_creator(_args: TokenStream, input: TokenStream) -> TokenStream {
    let mut item_struct = parse_macro_input!(input as ItemStruct);
    if let syn::Fields::Named(ref mut fields) = item_struct.fields {
        // The field type is spelled with its full path so that the expansion does not depend on
        // `ConcreteDataType` happening to be in scope where the struct is defined.
        let result = syn::Field::parse_named.parse2(quote! {
            input_types: arc_swap::ArcSwapOption<Vec<datatypes::prelude::ConcreteDataType>>
        });
        match result {
            Ok(field) => fields.named.push(field),
            Err(e) => return e.into_compile_error().into(),
        }
    } else {
        return quote_spanned!(
            item_struct.fields.span() => compile_error!(
                "This attribute macro needs to add fields to its annotated struct, \
                 so the struct must have \"{}\".")
        )
        .into();
    }
    quote! {
        #item_struct
    }
    .into()
}

View File

@@ -0,0 +1,14 @@
use common_function_macro::as_aggr_func_creator;
use common_function_macro::AggrFuncTypeStore;
use static_assertions::{assert_fields, assert_impl_all};
// Minimal fixture for the two macros: the test below checks that the attribute has added an
// `input_types` field and that the derive has implemented `AggrFuncTypeStore`.
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
struct Foo {}
/// Verifies the shape of `Foo` after macro expansion: it implements the expected traits, has an
/// `input_types` field, and can be built through `Default`.
#[test]
fn test_derive() {
    assert_impl_all!(Foo: std::fmt::Debug, Default, AggrFuncTypeStore);
    assert_fields!(Foo: input_types);
    let _ = Foo::default();
}

View File

@@ -13,6 +13,7 @@ arc-swap = "1.0"
chrono-tz = "0.6"
common-error = { path = "../error" }
common-query = { path = "../query" }
common-function-macro = { path = "../function-macro" }
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2" }
datatypes = { path = "../../datatypes" }
libc = "0.2"

View File

@@ -1,10 +1,8 @@
use std::cmp::Ordering;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_query::error::{
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result,
};
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::vectors::ConstantVector;
@@ -98,10 +96,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ArgmaxAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ArgmaxAccumulatorCreator {}
impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -124,23 +121,6 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::uint64_datatype())
}

View File

@@ -1,10 +1,8 @@
use std::cmp::Ordering;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_query::error::{
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result,
};
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::vectors::ConstantVector;
@@ -107,10 +105,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ArgminAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ArgminAccumulatorCreator {}
impl AggregateFunctionCreator for ArgminAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -133,23 +130,6 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::uint32_datatype())
}

View File

@@ -1,10 +1,9 @@
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu,
Result,
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -118,10 +117,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct DiffAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct DiffAccumulatorCreator {}
impl AggregateFunctionCreator for DiffAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -144,24 +142,6 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() != input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);

View File

@@ -1,10 +1,9 @@
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, InvalidInputStateSnafu,
Result,
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -125,10 +124,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct MeanAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct MeanAccumulatorCreator {}
impl AggregateFunctionCreator for MeanAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -151,23 +149,6 @@ impl AggregateFunctionCreator for MeanAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);

View File

@@ -2,10 +2,9 @@ use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu,
Result,
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -175,10 +174,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct MedianAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct MedianAccumulatorCreator {}
impl AggregateFunctionCreator for MedianAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -201,23 +199,6 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);

View File

@@ -64,55 +64,24 @@ pub(crate) struct AggregateFunctions;
impl AggregateFunctions {
pub fn register(registry: &FunctionRegistry) {
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"median",
1,
Arc::new(|| Arc::new(MedianAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"diff",
1,
Arc::new(|| Arc::new(DiffAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"mean",
1,
Arc::new(|| Arc::new(MeanAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"polyval",
2,
Arc::new(|| Arc::new(PolyvalAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"argmax",
1,
Arc::new(|| Arc::new(ArgmaxAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"argmin",
1,
Arc::new(|| Arc::new(ArgminAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"diff",
1,
Arc::new(|| Arc::new(DiffAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"percentile",
2,
Arc::new(|| Arc::new(PercentileAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"scipystatsnormcdf",
2,
Arc::new(|| Arc::new(ScipyStatsNormCdfAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"scipystatsnormpdf",
2,
Arc::new(|| Arc::new(ScipyStatsNormPdfAccumulatorCreator::default())),
)));
// Registers one aggregate function in `registry` under `$name`. `$arg_count` is forwarded to
// `AggregateFunctionMeta::new` unchanged (presumably the number of arguments the function
// accepts; confirm against `AggregateFunctionMeta`). The registered closure builds a new
// `$creator` through `Default` each time it is called.
macro_rules! register_aggr_func {
($name :expr, $arg_count :expr, $creator :ty) => {
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
$name,
$arg_count,
Arc::new(|| Arc::new(<$creator>::default())),
)));
};
}
register_aggr_func!("median", 1, MedianAccumulatorCreator);
register_aggr_func!("diff", 1, DiffAccumulatorCreator);
register_aggr_func!("mean", 1, MeanAccumulatorCreator);
register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator);
register_aggr_func!("argmax", 1, ArgmaxAccumulatorCreator);
register_aggr_func!("argmin", 1, ArgminAccumulatorCreator);
register_aggr_func!("percentile", 2, PercentileAccumulatorCreator);
register_aggr_func!("scipystatsnormcdf", 2, ScipyStatsNormCdfAccumulatorCreator);
register_aggr_func!("scipystatsnormpdf", 2, ScipyStatsNormPdfAccumulatorCreator);
}
}

View File

@@ -2,10 +2,10 @@ use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result,
FromScalarValueSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -231,10 +231,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct PercentileAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct PercentileAccumulatorCreator {}
impl AggregateFunctionCreator for PercentileAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -257,23 +256,6 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);

View File

@@ -1,10 +1,10 @@
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result,
FromScalarValueSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -187,10 +187,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct PolyvalAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct PolyvalAccumulatorCreator {}
impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -213,23 +212,6 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);

View File

@@ -1,10 +1,9 @@
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu,
Result,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -174,10 +173,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ScipyStatsNormCdfAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ScipyStatsNormCdfAccumulatorCreator {}
impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);

View File

@@ -1,10 +1,9 @@
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu,
Result,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -174,10 +173,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ScipyStatsNormPdfAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ScipyStatsNormPdfAccumulatorCreator {}
impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);

View File

@@ -19,4 +19,4 @@ statrs = "0.15"
[dev-dependencies]
tokio = { version = "1.0", features = ["full"] }
common-base = {path = "../base"}
common-base = { path = "../base" }

View File

@@ -45,20 +45,14 @@ pub trait Accumulator: Send + Sync + Debug {
}
/// An `AggregateFunctionCreator` dynamically creates `Accumulator`.
/// DataFusion does not provide the input data's types when creating Accumulator, we have to stores
/// it somewhere else ourself. So an `AggregateFunctionCreator` often has a companion struct, that
/// can store the input data types, and knows the output and states types of an Accumulator.
/// That's how we create the Accumulator generically.
pub trait AggregateFunctionCreator: Send + Sync + Debug {
///
/// An `AggregateFunctionCreator` often has a companion struct, that
/// can store the input data types (impl [AggrFuncTypeStore]), and knows the output and states
/// types of an Accumulator.
pub trait AggregateFunctionCreator: AggrFuncTypeStore {
/// Create a function that can create a new accumulator with some input data type.
fn creator(&self) -> AccumulatorCreatorFunction;
/// Get the input data type of the Accumulator.
fn input_types(&self) -> Result<Vec<ConcreteDataType>>;
/// Store the input data type that is provided by DataFusion at runtime.
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()>;
/// Get the Accumulator's output data type.
fn output_type(&self) -> Result<ConcreteDataType>;
@@ -66,6 +60,20 @@ pub trait AggregateFunctionCreator: Send + Sync + Debug {
fn state_types(&self) -> Result<Vec<ConcreteDataType>>;
}
/// `AggrFuncTypeStore` stores the aggregate function's input data's types.
///
/// When creating Accumulator generically, we have to know the input data's types.
/// However, DataFusion does not provide the input data's types at the time of creating Accumulator.
/// To solve the problem, we store the datatypes upfront here.
pub trait AggrFuncTypeStore: Send + Sync + Debug {
/// Get the input data types of the Accumulator.
fn input_types(&self) -> Result<Vec<ConcreteDataType>>;
/// Store the input data types that are provided by DataFusion at runtime (when it is evaluating
/// return type function).
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()>;
}
pub fn make_accumulator_function(
creator: Arc<dyn AggregateFunctionCreator>,
) -> AccumulatorFunctionImpl {

View File

@@ -1,4 +1,4 @@
mod accumulator;
pub mod accumulator;
mod expr;
mod udaf;
mod udf;
@@ -177,11 +177,7 @@ mod tests {
#[derive(Debug)]
struct DummyAccumulatorCreator;
impl AggregateFunctionCreator for DummyAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
Arc::new(|_| Ok(Box::new(DummyAccumulator)))
}
impl AggrFuncTypeStore for DummyAccumulatorCreator {
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
Ok(vec![ConcreteDataType::float64_datatype()])
}
@@ -189,6 +185,12 @@ mod tests {
fn set_input_types(&self, _: Vec<ConcreteDataType>) -> Result<()> {
Ok(())
}
}
impl AggregateFunctionCreator for DummyAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
Arc::new(|_| Ok(Box::new(DummyAccumulator)))
}
fn output_type(&self) -> Result<ConcreteDataType> {
Ok(self.input_types()?.into_iter().next().unwrap())

View File

@@ -32,6 +32,7 @@ tokio = "1.0"
[dev-dependencies]
approx_eq = "0.1"
common-function-macro = { path = "../common/function-macro" }
num = "0.4"
num-traits = "0.2"
format_num = "0.1"

View File

@@ -2,12 +2,12 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider};
use catalog::{
CatalogList, CatalogProvider, SchemaProvider, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME,
};
use common_function::scalars::aggregate::AggregateFunctionMeta;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::CreateAccumulatorSnafu;
use common_query::error::Result as QueryResult;
use common_query::logical_plan::Accumulator;
@@ -54,10 +54,9 @@ where
}
}
#[derive(Debug, Default)]
struct MySumAccumulatorCreator {
input_type: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
struct MySumAccumulatorCreator {}
impl AggregateFunctionCreator for MySumAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -80,26 +79,6 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
creator
}
fn input_types(&self) -> QueryResult<Vec<ConcreteDataType>> {
Ok(self.input_type
.load()
.as_ref()
.expect("input_type is not present, check if DataFusion has changed its UDAF execution logic")
.as_ref()
.clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> QueryResult<()> {
let old = self.input_type.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
assert_eq!(old.len(), input_types.len());
old.iter().zip(input_types.iter()).for_each(|(x, y)|
assert_eq!(x, y, "input type {:?} != {:?}, check if DataFusion has changed its UDAF execution logic", x, y)
);
}
Ok(())
}
fn output_type(&self) -> QueryResult<ConcreteDataType> {
let input_type = &self.input_types()?[0];
with_match_primitive_type_id!(