From 64b6b2afe1990eb321225cceadca4880afab1f07 Mon Sep 17 00:00:00 2001 From: LFC Date: Tue, 13 Sep 2022 10:39:44 +0800 Subject: [PATCH] feat: procedure macro to help writing UDAF (#242) * feat: procedure macro to help writing UDAF * resolve code review comments Co-authored-by: luofucong --- Cargo.lock | 15 ++++ Cargo.toml | 1 + .../how-to/how-to-write-aggregate-function.md | 16 +++-- src/common/function-macro/Cargo.toml | 18 +++++ src/common/function-macro/src/lib.rs | 71 +++++++++++++++++++ .../function-macro/tests/test_derive.rs | 14 ++++ src/common/function/Cargo.toml | 1 + .../function/src/scalars/aggregate/argmax.rs | 30 ++------ .../function/src/scalars/aggregate/argmin.rs | 30 ++------ .../function/src/scalars/aggregate/diff.rs | 30 ++------ .../function/src/scalars/aggregate/mean.rs | 29 ++------ .../function/src/scalars/aggregate/median.rs | 29 ++------ .../function/src/scalars/aggregate/mod.rs | 69 +++++------------- .../src/scalars/aggregate/percentile.rs | 28 ++------ .../function/src/scalars/aggregate/polyval.rs | 28 ++------ .../scalars/aggregate/scipy_stats_norm_cdf.rs | 29 ++------ .../scalars/aggregate/scipy_stats_norm_pdf.rs | 29 ++------ src/common/query/Cargo.toml | 2 +- .../query/src/logical_plan/accumulator.rs | 30 +++++--- src/common/query/src/logical_plan/mod.rs | 14 ++-- src/query/Cargo.toml | 1 + src/query/tests/my_sum_udaf_example.rs | 29 ++------ 22 files changed, 226 insertions(+), 317 deletions(-) create mode 100644 src/common/function-macro/Cargo.toml create mode 100644 src/common/function-macro/src/lib.rs create mode 100644 src/common/function-macro/tests/test_derive.rs diff --git a/Cargo.lock b/Cargo.lock index 8ad2d13759..076d598f97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -816,6 +816,7 @@ dependencies = [ "arrow2", "chrono-tz", "common-error", + "common-function-macro", "common-query", "datafusion-common", "datatypes", @@ -836,6 +837,19 @@ dependencies = [ "statrs", ] +[[package]] +name = "common-function-macro" +version = "0.1.0" +dependencies = [ + "arc-swap", + "common-query", + "datatypes", + "quote", + "snafu", + "static_assertions", + "syn", +] + [[package]] name = "common-grpc" version = "0.1.0" @@ -3670,6 +3684,7 @@ dependencies = [ "catalog", "common-error", "common-function", + "common-function-macro", "common-query", "common-recordbatch", "common-telemetry", diff --git a/Cargo.toml b/Cargo.toml index 0b1bd2b318..8baa0c6bba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "src/common/base", "src/common/error", "src/common/function", + "src/common/function-macro", "src/common/grpc", "src/common/query", "src/common/recordbatch", diff --git a/docs/how-to/how-to-write-aggregate-function.md b/docs/how-to/how-to-write-aggregate-function.md index 1d0125ef6a..15624353cd 100644 --- a/docs/how-to/how-to-write-aggregate-function.md +++ b/docs/how-to/how-to-write-aggregate-function.md @@ -8,27 +8,29 @@ So is there a way we can make an aggregate function that automatically match the # 1. Impl `AggregateFunctionCreator` trait for your accumulator creator. -You must first define a struct that can store the input data's type. For example, +You must first define a struct that will be used to create your accumulator. For example, ```Rust -struct MySumAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, AggrFuncTypeStore)] +struct MySumAccumulatorCreator {} ``` +Attribute macro `#[as_aggr_func_creator]` and derive macro `#[derive(Debug, AggrFuncTypeStore)]` must both annotated on the struct. They work together to provide a storage of aggregate function's input data types, which are needed for creating generic accumulator later. + +> Note that the `as_aggr_func_creator` macro will add fields to the struct, so the struct cannot be defined as an empty struct without field like `struct Foo;`, neither as a new type like `struct Foo(bar)`. + Then impl `AggregateFunctionCreator` trait on it. The definition of the trait is: ```Rust pub trait AggregateFunctionCreator: Send + Sync + Debug { fn creator(&self) -> AccumulatorCreatorFunction; - fn input_types(&self) -> Vec; - fn set_input_types(&self, input_types: Vec); fn output_type(&self) -> ConcreteDataType; fn state_types(&self) -> Vec; } ``` -our query engine will call `set_input_types` the very first, so you can use input data's type in methods that return output type and state types. +You can use input data's type in methods that return output type and state types (just invoke `input_types()`). The output type is aggregate function's output data's type. For example, `SUM` aggregate function's output type is `u64` for a `u32` datatype column. The state types are accumulator's internal states' types. Take `AVG` aggregate function on a `i32` column as example, it's state types are `i64` (for sum) and `u64` (for count). diff --git a/src/common/function-macro/Cargo.toml b/src/common/function-macro/Cargo.toml new file mode 100644 index 0000000000..2bbb4e2d67 --- /dev/null +++ b/src/common/function-macro/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "common-function-macro" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +common-query = { path = "../query" } +datatypes = { path = "../../datatypes" } +quote = "1.0" +snafu = { version = "0.7", features = ["backtraces"] } +syn = "1.0" + +[dev-dependencies] +arc-swap = "1.0" +static_assertions = "1.1.0" diff --git a/src/common/function-macro/src/lib.rs b/src/common/function-macro/src/lib.rs new file mode 100644 index 0000000000..5d95d346b3 --- /dev/null +++ b/src/common/function-macro/src/lib.rs @@ -0,0 +1,71 @@ +use proc_macro::TokenStream; +use quote::{quote, quote_spanned}; +use syn::parse::Parser; +use syn::spanned::Spanned; +use syn::{parse_macro_input, DeriveInput, ItemStruct}; + +/// Make struct implemented trait [AggrFuncTypeStore], which is necessary when writing UDAF. +/// This derive macro is expect to be used along with attribute macro [as_aggr_func_creator]. +#[proc_macro_derive(AggrFuncTypeStore)] +pub fn aggr_func_type_store_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + impl_aggr_func_type_store(&ast) +} + +fn impl_aggr_func_type_store(ast: &DeriveInput) -> TokenStream { + let name = &ast.ident; + let gen = quote! { + use common_query::logical_plan::accumulator::AggrFuncTypeStore; + use common_query::error::{InvalidInputStateSnafu, Error as QueryError}; + use datatypes::prelude::ConcreteDataType; + + impl AggrFuncTypeStore for #name { + fn input_types(&self) -> std::result::Result, QueryError> { + let input_types = self.input_types.load(); + snafu::ensure!(input_types.is_some(), InvalidInputStateSnafu); + Ok(input_types.as_ref().unwrap().as_ref().clone()) + } + + fn set_input_types(&self, input_types: Vec) -> std::result::Result<(), QueryError> { + let old = self.input_types.swap(Some(std::sync::Arc::new(input_types.clone()))); + if let Some(old) = old { + snafu::ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); + for (x, y) in old.iter().zip(input_types.iter()) { + snafu::ensure!(x == y, InvalidInputStateSnafu); + } + } + Ok(()) + } + } + }; + gen.into() +} + +/// A struct can be used as a creator for aggregate function if it has been annotated with this +/// attribute first. This attribute add a necessary field which is intended to store the input +/// data's types to the struct. +/// This attribute is expected to be used along with derive macro [AggrFuncTypeStore]. +#[proc_macro_attribute] +pub fn as_aggr_func_creator(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut item_struct = parse_macro_input!(input as ItemStruct); + if let syn::Fields::Named(ref mut fields) = item_struct.fields { + let result = syn::Field::parse_named.parse2(quote! { + input_types: arc_swap::ArcSwapOption> + }); + match result { + Ok(field) => fields.named.push(field), + Err(e) => return e.into_compile_error().into(), + } + } else { + return quote_spanned!( + item_struct.fields.span() => compile_error!( + "This attribute macro needs to add fields to the its annotated struct, \ + so the struct must have \"{}\".") + ) + .into(); + } + quote! { + #item_struct + } + .into() +} diff --git a/src/common/function-macro/tests/test_derive.rs b/src/common/function-macro/tests/test_derive.rs new file mode 100644 index 0000000000..58199aa24c --- /dev/null +++ b/src/common/function-macro/tests/test_derive.rs @@ -0,0 +1,14 @@ +use common_function_macro::as_aggr_func_creator; +use common_function_macro::AggrFuncTypeStore; +use static_assertions::{assert_fields, assert_impl_all}; + +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +struct Foo {} + +#[test] +fn test_derive() { + Foo::default(); + assert_fields!(Foo: input_types); + assert_impl_all!(Foo: std::fmt::Debug, Default, AggrFuncTypeStore); +} diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index e462b6be17..c25e63e8a8 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -13,6 +13,7 @@ arc-swap = "1.0" chrono-tz = "0.6" common-error = { path = "../error" } common-query = { path = "../query" } +common-function-macro = { path = "../function-macro" } datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2" } datatypes = { path = "../../datatypes" } libc = "0.2" diff --git a/src/common/function/src/scalars/aggregate/argmax.rs b/src/common/function/src/scalars/aggregate/argmax.rs index 2946864736..390591c676 100644 --- a/src/common/function/src/scalars/aggregate/argmax.rs +++ b/src/common/function/src/scalars/aggregate/argmax.rs @@ -1,10 +1,8 @@ use std::cmp::Ordering; use std::sync::Arc; -use arc_swap::ArcSwapOption; -use common_query::error::{ - BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result, -}; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; +use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::vectors::ConstantVector; @@ -98,10 +96,9 @@ where } } -#[derive(Debug, Default)] -pub struct ArgmaxAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct ArgmaxAccumulatorCreator {} impl AggregateFunctionCreator for ArgmaxAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -124,23 +121,6 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { Ok(ConcreteDataType::uint64_datatype()) } diff --git a/src/common/function/src/scalars/aggregate/argmin.rs b/src/common/function/src/scalars/aggregate/argmin.rs index 63ccc7c190..a3d8457ce8 100644 --- a/src/common/function/src/scalars/aggregate/argmin.rs +++ b/src/common/function/src/scalars/aggregate/argmin.rs @@ -1,10 +1,8 @@ use std::cmp::Ordering; use std::sync::Arc; -use arc_swap::ArcSwapOption; -use common_query::error::{ - BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result, -}; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; +use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::vectors::ConstantVector; @@ -107,10 +105,9 @@ where } } -#[derive(Debug, Default)] -pub struct ArgminAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct ArgminAccumulatorCreator {} impl AggregateFunctionCreator for ArgminAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -133,23 +130,6 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { Ok(ConcreteDataType::uint32_datatype()) } diff --git a/src/common/function/src/scalars/aggregate/diff.rs b/src/common/function/src/scalars/aggregate/diff.rs index e414434a9c..e43c66b4e4 100644 --- a/src/common/function/src/scalars/aggregate/diff.rs +++ b/src/common/function/src/scalars/aggregate/diff.rs @@ -1,10 +1,9 @@ use std::marker::PhantomData; use std::sync::Arc; -use arc_swap::ArcSwapOption; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{ - CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu, - Result, + CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result, }; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; @@ -118,10 +117,9 @@ where } } -#[derive(Debug, Default)] -pub struct DiffAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct DiffAccumulatorCreator {} impl AggregateFunctionCreator for DiffAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -144,24 +142,6 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() != input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { let input_types = self.input_types()?; ensure!(input_types.len() == 1, InvalidInputStateSnafu); diff --git a/src/common/function/src/scalars/aggregate/mean.rs b/src/common/function/src/scalars/aggregate/mean.rs index 64b4ab528f..f4cf0839ae 100644 --- a/src/common/function/src/scalars/aggregate/mean.rs +++ b/src/common/function/src/scalars/aggregate/mean.rs @@ -1,10 +1,9 @@ use std::marker::PhantomData; use std::sync::Arc; -use arc_swap::ArcSwapOption; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{ - BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, InvalidInputStateSnafu, - Result, + BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, Result, }; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; @@ -125,10 +124,9 @@ where } } -#[derive(Debug, Default)] -pub struct MeanAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct MeanAccumulatorCreator {} impl AggregateFunctionCreator for MeanAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -151,23 +149,6 @@ impl AggregateFunctionCreator for MeanAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { let input_types = self.input_types()?; ensure!(input_types.len() == 1, InvalidInputStateSnafu); diff --git a/src/common/function/src/scalars/aggregate/median.rs b/src/common/function/src/scalars/aggregate/median.rs index 04ab7632e6..ef2e1bf3f2 100644 --- a/src/common/function/src/scalars/aggregate/median.rs +++ b/src/common/function/src/scalars/aggregate/median.rs @@ -2,10 +2,9 @@ use std::cmp::Reverse; use std::collections::BinaryHeap; use std::sync::Arc; -use arc_swap::ArcSwapOption; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{ - CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu, - Result, + CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result, }; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; @@ -175,10 +174,9 @@ where } } -#[derive(Debug, Default)] -pub struct MedianAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct MedianAccumulatorCreator {} impl AggregateFunctionCreator for MedianAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -201,23 +199,6 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { let input_types = self.input_types()?; ensure!(input_types.len() == 1, InvalidInputStateSnafu); diff --git a/src/common/function/src/scalars/aggregate/mod.rs b/src/common/function/src/scalars/aggregate/mod.rs index 9779dd7463..06dbf36561 100644 --- a/src/common/function/src/scalars/aggregate/mod.rs +++ b/src/common/function/src/scalars/aggregate/mod.rs @@ -64,55 +64,24 @@ pub(crate) struct AggregateFunctions; impl AggregateFunctions { pub fn register(registry: &FunctionRegistry) { - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "median", - 1, - Arc::new(|| Arc::new(MedianAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "diff", - 1, - Arc::new(|| Arc::new(DiffAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "mean", - 1, - Arc::new(|| Arc::new(MeanAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "polyval", - 2, - Arc::new(|| Arc::new(PolyvalAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "argmax", - 1, - Arc::new(|| Arc::new(ArgmaxAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "argmin", - 1, - Arc::new(|| Arc::new(ArgminAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "diff", - 1, - Arc::new(|| Arc::new(DiffAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "percentile", - 2, - Arc::new(|| Arc::new(PercentileAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "scipystatsnormcdf", - 2, - Arc::new(|| Arc::new(ScipyStatsNormCdfAccumulatorCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "scipystatsnormpdf", - 2, - Arc::new(|| Arc::new(ScipyStatsNormPdfAccumulatorCreator::default())), - ))); + macro_rules! register_aggr_func { + ($name :expr, $arg_count :expr, $creator :ty) => { + registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( + $name, + $arg_count, + Arc::new(|| Arc::new(<$creator>::default())), + ))); + }; + } + + register_aggr_func!("median", 1, MedianAccumulatorCreator); + register_aggr_func!("diff", 1, DiffAccumulatorCreator); + register_aggr_func!("mean", 1, MeanAccumulatorCreator); + register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator); + register_aggr_func!("argmax", 1, ArgmaxAccumulatorCreator); + register_aggr_func!("argmin", 1, ArgminAccumulatorCreator); + register_aggr_func!("percentile", 2, PercentileAccumulatorCreator); + register_aggr_func!("scipystatsnormcdf", 2, ScipyStatsNormCdfAccumulatorCreator); + register_aggr_func!("scipystatsnormpdf", 2, ScipyStatsNormPdfAccumulatorCreator); } } diff --git a/src/common/function/src/scalars/aggregate/percentile.rs b/src/common/function/src/scalars/aggregate/percentile.rs index f14e139faf..3b79a4bdf1 100644 --- a/src/common/function/src/scalars/aggregate/percentile.rs +++ b/src/common/function/src/scalars/aggregate/percentile.rs @@ -2,10 +2,10 @@ use std::cmp::Reverse; use std::collections::BinaryHeap; use std::sync::Arc; -use arc_swap::ArcSwapOption; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{ self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, - FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result, + FromScalarValueSnafu, InvalidInputColSnafu, Result, }; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; @@ -231,10 +231,9 @@ where } } -#[derive(Debug, Default)] -pub struct PercentileAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct PercentileAccumulatorCreator {} impl AggregateFunctionCreator for PercentileAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -257,23 +256,6 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { let input_types = self.input_types()?; ensure!(input_types.len() == 2, InvalidInputStateSnafu); diff --git a/src/common/function/src/scalars/aggregate/polyval.rs b/src/common/function/src/scalars/aggregate/polyval.rs index d7b37ecbb4..5a87c49d85 100644 --- a/src/common/function/src/scalars/aggregate/polyval.rs +++ b/src/common/function/src/scalars/aggregate/polyval.rs @@ -1,10 +1,10 @@ use std::marker::PhantomData; use std::sync::Arc; -use arc_swap::ArcSwapOption; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{ self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, - FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result, + FromScalarValueSnafu, InvalidInputColSnafu, Result, }; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; @@ -187,10 +187,9 @@ where } } -#[derive(Debug, Default)] -pub struct PolyvalAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct PolyvalAccumulatorCreator {} impl AggregateFunctionCreator for PolyvalAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -213,23 +212,6 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { let input_types = self.input_types()?; ensure!(input_types.len() == 2, InvalidInputStateSnafu); diff --git a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs b/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs index e21f25d3f0..4346b1b004 100644 --- a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs +++ b/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs @@ -1,10 +1,9 @@ use std::sync::Arc; -use arc_swap::ArcSwapOption; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{ self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, - FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, - Result, + FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result, }; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; @@ -174,10 +173,9 @@ where } } -#[derive(Debug, Default)] -pub struct ScipyStatsNormCdfAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct ScipyStatsNormCdfAccumulatorCreator {} impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { let input_types = self.input_types()?; ensure!(input_types.len() == 2, InvalidInputStateSnafu); diff --git a/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs b/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs index 689659cde8..b38b90d0dd 100644 --- a/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs +++ b/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs @@ -1,10 +1,9 @@ use std::sync::Arc; -use arc_swap::ArcSwapOption; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{ self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, - FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, - Result, + FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result, }; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; @@ -174,10 +173,9 @@ where } } -#[derive(Debug, Default)] -pub struct ScipyStatsNormPdfAccumulatorCreator { - input_types: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +pub struct ScipyStatsNormPdfAccumulatorCreator {} impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator { creator } - fn input_types(&self) -> Result> { - let input_types = self.input_types.load(); - ensure!(input_types.is_some(), InvalidInputStateSnafu); - Ok(input_types.as_ref().unwrap().as_ref().clone()) - } - - fn set_input_types(&self, input_types: Vec) -> Result<()> { - let old = self.input_types.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - ensure!(old.len() == input_types.len(), InvalidInputStateSnafu); - for (x, y) in old.iter().zip(input_types.iter()) { - ensure!(x == y, InvalidInputStateSnafu); - } - } - Ok(()) - } - fn output_type(&self) -> Result { let input_types = self.input_types()?; ensure!(input_types.len() == 2, InvalidInputStateSnafu); diff --git a/src/common/query/Cargo.toml b/src/common/query/Cargo.toml index b21c5e3058..c75235c77d 100644 --- a/src/common/query/Cargo.toml +++ b/src/common/query/Cargo.toml @@ -19,4 +19,4 @@ statrs = "0.15" [dev-dependencies] tokio = { version = "1.0", features = ["full"] } -common-base = {path = "../base"} +common-base = { path = "../base" } diff --git a/src/common/query/src/logical_plan/accumulator.rs b/src/common/query/src/logical_plan/accumulator.rs index 949024df95..d0c60902ef 100644 --- a/src/common/query/src/logical_plan/accumulator.rs +++ b/src/common/query/src/logical_plan/accumulator.rs @@ -45,20 +45,14 @@ pub trait Accumulator: Send + Sync + Debug { } /// An `AggregateFunctionCreator` dynamically creates `Accumulator`. -/// DataFusion does not provide the input data's types when creating Accumulator, we have to stores -/// it somewhere else ourself. So an `AggregateFunctionCreator` often has a companion struct, that -/// can store the input data types, and knows the output and states types of an Accumulator. -/// That's how we create the Accumulator generically. -pub trait AggregateFunctionCreator: Send + Sync + Debug { +/// +/// An `AggregateFunctionCreator` often has a companion struct, that +/// can store the input data types (impl [AggrFuncTypeStore]), and knows the output and states +/// types of an Accumulator. +pub trait AggregateFunctionCreator: AggrFuncTypeStore { /// Create a function that can create a new accumulator with some input data type. fn creator(&self) -> AccumulatorCreatorFunction; - /// Get the input data type of the Accumulator. - fn input_types(&self) -> Result>; - - /// Store the input data type that is provided by DataFusion at runtime. - fn set_input_types(&self, input_types: Vec) -> Result<()>; - /// Get the Accumulator's output data type. fn output_type(&self) -> Result; @@ -66,6 +60,20 @@ pub trait AggregateFunctionCreator: Send + Sync + Debug { fn state_types(&self) -> Result>; } +/// `AggrFuncTypeStore` stores the aggregate function's input data's types. +/// +/// When creating Accumulator generically, we have to know the input data's types. +/// However, DataFusion does not provide the input data's types at the time of creating Accumulator. +/// To solve the problem, we store the datatypes upfront here. +pub trait AggrFuncTypeStore: Send + Sync + Debug { + /// Get the input data types of the Accumulator. + fn input_types(&self) -> Result>; + + /// Store the input data types that are provided by DataFusion at runtime (when it is evaluating + /// return type function). + fn set_input_types(&self, input_types: Vec) -> Result<()>; +} + pub fn make_accumulator_function( creator: Arc, ) -> AccumulatorFunctionImpl { diff --git a/src/common/query/src/logical_plan/mod.rs b/src/common/query/src/logical_plan/mod.rs index af12b9fc3d..c19ca49f54 100644 --- a/src/common/query/src/logical_plan/mod.rs +++ b/src/common/query/src/logical_plan/mod.rs @@ -1,4 +1,4 @@ -mod accumulator; +pub mod accumulator; mod expr; mod udaf; mod udf; @@ -177,11 +177,7 @@ mod tests { #[derive(Debug)] struct DummyAccumulatorCreator; - impl AggregateFunctionCreator for DummyAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - Arc::new(|_| Ok(Box::new(DummyAccumulator))) - } - + impl AggrFuncTypeStore for DummyAccumulatorCreator { fn input_types(&self) -> Result> { Ok(vec![ConcreteDataType::float64_datatype()]) } @@ -189,6 +185,12 @@ mod tests { fn set_input_types(&self, _: Vec) -> Result<()> { Ok(()) } + } + + impl AggregateFunctionCreator for DummyAccumulatorCreator { + fn creator(&self) -> AccumulatorCreatorFunction { + Arc::new(|_| Ok(Box::new(DummyAccumulator))) + } fn output_type(&self) -> Result { Ok(self.input_types()?.into_iter().next().unwrap()) diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 86abf5a798..59fc6ac4fc 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -32,6 +32,7 @@ tokio = "1.0" [dev-dependencies] approx_eq = "0.1" +common-function-macro = { path = "../common/function-macro" } num = "0.4" num-traits = "0.2" format_num = "0.1" diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs index 8b4651d07d..eb2144ae89 100644 --- a/src/query/tests/my_sum_udaf_example.rs +++ b/src/query/tests/my_sum_udaf_example.rs @@ -2,12 +2,12 @@ use std::fmt::Debug; use std::marker::PhantomData; use std::sync::Arc; -use arc_swap::ArcSwapOption; use catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider}; use catalog::{ CatalogList, CatalogProvider, SchemaProvider, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, }; use common_function::scalars::aggregate::AggregateFunctionMeta; +use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::CreateAccumulatorSnafu; use common_query::error::Result as QueryResult; use common_query::logical_plan::Accumulator; @@ -54,10 +54,9 @@ where } } -#[derive(Debug, Default)] -struct MySumAccumulatorCreator { - input_type: ArcSwapOption>, -} +#[as_aggr_func_creator] +#[derive(Debug, Default, AggrFuncTypeStore)] +struct MySumAccumulatorCreator {} impl AggregateFunctionCreator for MySumAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { @@ -80,26 +79,6 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator { creator } - fn input_types(&self) -> QueryResult> { - Ok(self.input_type - .load() - .as_ref() - .expect("input_type is not present, check if DataFusion has changed its UDAF execution logic") - .as_ref() - .clone()) - } - - fn set_input_types(&self, input_types: Vec) -> QueryResult<()> { - let old = self.input_type.swap(Some(Arc::new(input_types.clone()))); - if let Some(old) = old { - assert_eq!(old.len(), input_types.len()); - old.iter().zip(input_types.iter()).for_each(|(x, y)| - assert_eq!(x, y, "input type {:?} != {:?}, check if DataFusion has changed its UDAF execution logic", x, y) - ); - } - Ok(()) - } - fn output_type(&self) -> QueryResult { let input_type = &self.input_types()?[0]; with_match_primitive_type_id!(