mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-26 18:00:41 +00:00
feat: procedure macro to help writing UDAF (#242)
* feat: procedure macro to help writing UDAF * resolve code review comments Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
15
Cargo.lock
generated
15
Cargo.lock
generated
@@ -816,6 +816,7 @@ dependencies = [
|
||||
"arrow2",
|
||||
"chrono-tz",
|
||||
"common-error",
|
||||
"common-function-macro",
|
||||
"common-query",
|
||||
"datafusion-common",
|
||||
"datatypes",
|
||||
@@ -836,6 +837,19 @@ dependencies = [
|
||||
"statrs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "common-function-macro"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"common-query",
|
||||
"datatypes",
|
||||
"quote",
|
||||
"snafu",
|
||||
"static_assertions",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "common-grpc"
|
||||
version = "0.1.0"
|
||||
@@ -3670,6 +3684,7 @@ dependencies = [
|
||||
"catalog",
|
||||
"common-error",
|
||||
"common-function",
|
||||
"common-function-macro",
|
||||
"common-query",
|
||||
"common-recordbatch",
|
||||
"common-telemetry",
|
||||
|
||||
@@ -6,6 +6,7 @@ members = [
|
||||
"src/common/base",
|
||||
"src/common/error",
|
||||
"src/common/function",
|
||||
"src/common/function-macro",
|
||||
"src/common/grpc",
|
||||
"src/common/query",
|
||||
"src/common/recordbatch",
|
||||
|
||||
@@ -8,27 +8,29 @@ So is there a way we can make an aggregate function that automatically match the
|
||||
|
||||
# 1. Impl `AggregateFunctionCreator` trait for your accumulator creator.
|
||||
|
||||
You must first define a struct that can store the input data's type. For example,
|
||||
You must first define a struct that will be used to create your accumulator. For example,
|
||||
|
||||
```Rust
|
||||
struct MySumAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, AggrFuncTypeStore)]
|
||||
struct MySumAccumulatorCreator {}
|
||||
```
|
||||
|
||||
The attribute macro `#[as_aggr_func_creator]` and the derive macro `#[derive(Debug, AggrFuncTypeStore)]` must both be annotated on the struct. They work together to provide storage for the aggregate function's input data types, which are needed for creating a generic accumulator later.
|
||||
|
||||
> Note that the `as_aggr_func_creator` macro will add fields to the struct, so the struct cannot be defined as a unit struct without fields like `struct Foo;`, nor as a tuple struct (new type) like `struct Foo(bar)`.
|
||||
|
||||
Then impl `AggregateFunctionCreator` trait on it. The definition of the trait is:
|
||||
|
||||
```Rust
|
||||
pub trait AggregateFunctionCreator: Send + Sync + Debug {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction;
|
||||
fn input_types(&self) -> Vec<ConcreteDataType>;
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>);
|
||||
fn output_type(&self) -> ConcreteDataType;
|
||||
fn state_types(&self) -> Vec<ConcreteDataType>;
|
||||
}
|
||||
```
|
||||
|
||||
our query engine will call `set_input_types` the very first, so you can use input data's type in methods that return output type and state types.
|
||||
You can use the input data's types in the methods that return the output type and the state types (just invoke `input_types()`).
|
||||
|
||||
The output type is the aggregate function's output data type. For example, the `SUM` aggregate function's output type is `u64` for a `u32` datatype column. The state types are the types of the accumulator's internal states. Take the `AVG` aggregate function on an `i32` column as an example: its state types are `i64` (for sum) and `u64` (for count).
|
||||
|
||||
|
||||
18
src/common/function-macro/Cargo.toml
Normal file
18
src/common/function-macro/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "common-function-macro"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
common-query = { path = "../query" }
|
||||
datatypes = { path = "../../datatypes" }
|
||||
quote = "1.0"
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
syn = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
arc-swap = "1.0"
|
||||
static_assertions = "1.1.0"
|
||||
71
src/common/function-macro/src/lib.rs
Normal file
71
src/common/function-macro/src/lib.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use proc_macro::TokenStream;
|
||||
use quote::{quote, quote_spanned};
|
||||
use syn::parse::Parser;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::{parse_macro_input, DeriveInput, ItemStruct};
|
||||
|
||||
/// Makes the struct implement the trait [AggrFuncTypeStore], which is necessary when writing a UDAF.
|
||||
/// This derive macro is expected to be used along with the attribute macro [as_aggr_func_creator].
|
||||
#[proc_macro_derive(AggrFuncTypeStore)]
|
||||
pub fn aggr_func_type_store_derive(input: TokenStream) -> TokenStream {
|
||||
let ast = parse_macro_input!(input as DeriveInput);
|
||||
impl_aggr_func_type_store(&ast)
|
||||
}
|
||||
|
||||
fn impl_aggr_func_type_store(ast: &DeriveInput) -> TokenStream {
|
||||
let name = &ast.ident;
|
||||
let gen = quote! {
|
||||
use common_query::logical_plan::accumulator::AggrFuncTypeStore;
|
||||
use common_query::error::{InvalidInputStateSnafu, Error as QueryError};
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
|
||||
impl AggrFuncTypeStore for #name {
|
||||
fn input_types(&self) -> std::result::Result<Vec<ConcreteDataType>, QueryError> {
|
||||
let input_types = self.input_types.load();
|
||||
snafu::ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> std::result::Result<(), QueryError> {
|
||||
let old = self.input_types.swap(Some(std::sync::Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
snafu::ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
snafu::ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
};
|
||||
gen.into()
|
||||
}
|
||||
|
||||
/// A struct can be used as a creator for aggregate function if it has been annotated with this
|
||||
/// attribute first. This attribute adds a necessary field which is intended to store the input
|
||||
/// data's types to the struct.
|
||||
/// This attribute is expected to be used along with derive macro [AggrFuncTypeStore].
|
||||
#[proc_macro_attribute]
|
||||
pub fn as_aggr_func_creator(_args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let mut item_struct = parse_macro_input!(input as ItemStruct);
|
||||
if let syn::Fields::Named(ref mut fields) = item_struct.fields {
|
||||
let result = syn::Field::parse_named.parse2(quote! {
|
||||
input_types: arc_swap::ArcSwapOption<Vec<ConcreteDataType>>
|
||||
});
|
||||
match result {
|
||||
Ok(field) => fields.named.push(field),
|
||||
Err(e) => return e.into_compile_error().into(),
|
||||
}
|
||||
} else {
|
||||
return quote_spanned!(
|
||||
item_struct.fields.span() => compile_error!(
|
||||
"This attribute macro needs to add fields to the its annotated struct, \
|
||||
so the struct must have \"{}\".")
|
||||
)
|
||||
.into();
|
||||
}
|
||||
quote! {
|
||||
#item_struct
|
||||
}
|
||||
.into()
|
||||
}
|
||||
14
src/common/function-macro/tests/test_derive.rs
Normal file
14
src/common/function-macro/tests/test_derive.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use common_function_macro::as_aggr_func_creator;
|
||||
use common_function_macro::AggrFuncTypeStore;
|
||||
use static_assertions::{assert_fields, assert_impl_all};
|
||||
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
struct Foo {}
|
||||
|
||||
#[test]
|
||||
fn test_derive() {
|
||||
Foo::default();
|
||||
assert_fields!(Foo: input_types);
|
||||
assert_impl_all!(Foo: std::fmt::Debug, Default, AggrFuncTypeStore);
|
||||
}
|
||||
@@ -13,6 +13,7 @@ arc-swap = "1.0"
|
||||
chrono-tz = "0.6"
|
||||
common-error = { path = "../error" }
|
||||
common-query = { path = "../query" }
|
||||
common-function-macro = { path = "../function-macro" }
|
||||
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2" }
|
||||
datatypes = { path = "../../datatypes" }
|
||||
libc = "0.2"
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_query::error::{
|
||||
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result,
|
||||
};
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
use datatypes::vectors::ConstantVector;
|
||||
@@ -98,10 +96,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ArgmaxAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct ArgmaxAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -124,23 +121,6 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
Ok(ConcreteDataType::uint64_datatype())
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_query::error::{
|
||||
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result,
|
||||
};
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
use datatypes::vectors::ConstantVector;
|
||||
@@ -107,10 +105,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ArgminAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct ArgminAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for ArgminAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -133,23 +130,6 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
Ok(ConcreteDataType::uint32_datatype())
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu,
|
||||
Result,
|
||||
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
@@ -118,10 +117,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct DiffAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct DiffAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for DiffAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -144,24 +142,6 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() != input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 1, InvalidInputStateSnafu);
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, InvalidInputStateSnafu,
|
||||
Result,
|
||||
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
@@ -125,10 +124,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MeanAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct MeanAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for MeanAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -151,23 +149,6 @@ impl AggregateFunctionCreator for MeanAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 1, InvalidInputStateSnafu);
|
||||
|
||||
@@ -2,10 +2,9 @@ use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu,
|
||||
Result,
|
||||
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
@@ -175,10 +174,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MedianAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct MedianAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for MedianAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -201,23 +199,6 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 1, InvalidInputStateSnafu);
|
||||
|
||||
@@ -64,55 +64,24 @@ pub(crate) struct AggregateFunctions;
|
||||
|
||||
impl AggregateFunctions {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"median",
|
||||
1,
|
||||
Arc::new(|| Arc::new(MedianAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"diff",
|
||||
1,
|
||||
Arc::new(|| Arc::new(DiffAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"mean",
|
||||
1,
|
||||
Arc::new(|| Arc::new(MeanAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"polyval",
|
||||
2,
|
||||
Arc::new(|| Arc::new(PolyvalAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"argmax",
|
||||
1,
|
||||
Arc::new(|| Arc::new(ArgmaxAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"argmin",
|
||||
1,
|
||||
Arc::new(|| Arc::new(ArgminAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"diff",
|
||||
1,
|
||||
Arc::new(|| Arc::new(DiffAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"percentile",
|
||||
2,
|
||||
Arc::new(|| Arc::new(PercentileAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"scipystatsnormcdf",
|
||||
2,
|
||||
Arc::new(|| Arc::new(ScipyStatsNormCdfAccumulatorCreator::default())),
|
||||
)));
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
"scipystatsnormpdf",
|
||||
2,
|
||||
Arc::new(|| Arc::new(ScipyStatsNormPdfAccumulatorCreator::default())),
|
||||
)));
|
||||
macro_rules! register_aggr_func {
|
||||
($name :expr, $arg_count :expr, $creator :ty) => {
|
||||
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
|
||||
$name,
|
||||
$arg_count,
|
||||
Arc::new(|| Arc::new(<$creator>::default())),
|
||||
)));
|
||||
};
|
||||
}
|
||||
|
||||
register_aggr_func!("median", 1, MedianAccumulatorCreator);
|
||||
register_aggr_func!("diff", 1, DiffAccumulatorCreator);
|
||||
register_aggr_func!("mean", 1, MeanAccumulatorCreator);
|
||||
register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator);
|
||||
register_aggr_func!("argmax", 1, ArgmaxAccumulatorCreator);
|
||||
register_aggr_func!("argmin", 1, ArgminAccumulatorCreator);
|
||||
register_aggr_func!("percentile", 2, PercentileAccumulatorCreator);
|
||||
register_aggr_func!("scipystatsnormcdf", 2, ScipyStatsNormCdfAccumulatorCreator);
|
||||
register_aggr_func!("scipystatsnormpdf", 2, ScipyStatsNormPdfAccumulatorCreator);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@ use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
|
||||
FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result,
|
||||
FromScalarValueSnafu, InvalidInputColSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
@@ -231,10 +231,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PercentileAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct PercentileAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for PercentileAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -257,23 +256,6 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 2, InvalidInputStateSnafu);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
|
||||
FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result,
|
||||
FromScalarValueSnafu, InvalidInputColSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
@@ -187,10 +187,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PolyvalAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct PolyvalAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -213,23 +212,6 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 2, InvalidInputStateSnafu);
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
|
||||
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu,
|
||||
Result,
|
||||
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
@@ -174,10 +173,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ScipyStatsNormCdfAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct ScipyStatsNormCdfAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 2, InvalidInputStateSnafu);
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
|
||||
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu,
|
||||
Result,
|
||||
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
@@ -174,10 +173,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ScipyStatsNormPdfAccumulatorCreator {
|
||||
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct ScipyStatsNormPdfAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
let input_types = self.input_types.load();
|
||||
ensure!(input_types.is_some(), InvalidInputStateSnafu);
|
||||
Ok(input_types.as_ref().unwrap().as_ref().clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
|
||||
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
|
||||
for (x, y) in old.iter().zip(input_types.iter()) {
|
||||
ensure!(x == y, InvalidInputStateSnafu);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 2, InvalidInputStateSnafu);
|
||||
|
||||
@@ -19,4 +19,4 @@ statrs = "0.15"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
common-base = {path = "../base"}
|
||||
common-base = { path = "../base" }
|
||||
|
||||
@@ -45,20 +45,14 @@ pub trait Accumulator: Send + Sync + Debug {
|
||||
}
|
||||
|
||||
/// An `AggregateFunctionCreator` dynamically creates `Accumulator`.
|
||||
/// DataFusion does not provide the input data's types when creating Accumulator, we have to stores
|
||||
/// it somewhere else ourself. So an `AggregateFunctionCreator` often has a companion struct, that
|
||||
/// can store the input data types, and knows the output and states types of an Accumulator.
|
||||
/// That's how we create the Accumulator generically.
|
||||
pub trait AggregateFunctionCreator: Send + Sync + Debug {
|
||||
///
|
||||
/// An `AggregateFunctionCreator` often has a companion struct, that
|
||||
/// can store the input data types (impl [AggrFuncTypeStore]), and knows the output and states
|
||||
/// types of an Accumulator.
|
||||
pub trait AggregateFunctionCreator: AggrFuncTypeStore {
|
||||
/// Create a function that can create a new accumulator with some input data type.
|
||||
fn creator(&self) -> AccumulatorCreatorFunction;
|
||||
|
||||
/// Get the input data type of the Accumulator.
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>>;
|
||||
|
||||
/// Store the input data type that is provided by DataFusion at runtime.
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()>;
|
||||
|
||||
/// Get the Accumulator's output data type.
|
||||
fn output_type(&self) -> Result<ConcreteDataType>;
|
||||
|
||||
@@ -66,6 +60,20 @@ pub trait AggregateFunctionCreator: Send + Sync + Debug {
|
||||
fn state_types(&self) -> Result<Vec<ConcreteDataType>>;
|
||||
}
|
||||
|
||||
/// `AggrFuncTypeStore` stores the aggregate function's input data's types.
|
||||
///
|
||||
/// When creating Accumulator generically, we have to know the input data's types.
|
||||
/// However, DataFusion does not provide the input data's types at the time of creating Accumulator.
|
||||
/// To solve the problem, we store the datatypes upfront here.
|
||||
pub trait AggrFuncTypeStore: Send + Sync + Debug {
|
||||
/// Get the input data types of the Accumulator.
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>>;
|
||||
|
||||
/// Store the input data types that are provided by DataFusion at runtime (when it is evaluating
|
||||
/// the return type function).
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()>;
|
||||
}
|
||||
|
||||
pub fn make_accumulator_function(
|
||||
creator: Arc<dyn AggregateFunctionCreator>,
|
||||
) -> AccumulatorFunctionImpl {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
mod accumulator;
|
||||
pub mod accumulator;
|
||||
mod expr;
|
||||
mod udaf;
|
||||
mod udf;
|
||||
@@ -177,11 +177,7 @@ mod tests {
|
||||
#[derive(Debug)]
|
||||
struct DummyAccumulatorCreator;
|
||||
|
||||
impl AggregateFunctionCreator for DummyAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
Arc::new(|_| Ok(Box::new(DummyAccumulator)))
|
||||
}
|
||||
|
||||
impl AggrFuncTypeStore for DummyAccumulatorCreator {
|
||||
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
Ok(vec![ConcreteDataType::float64_datatype()])
|
||||
}
|
||||
@@ -189,6 +185,12 @@ mod tests {
|
||||
fn set_input_types(&self, _: Vec<ConcreteDataType>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl AggregateFunctionCreator for DummyAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
Arc::new(|_| Ok(Box::new(DummyAccumulator)))
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
Ok(self.input_types()?.into_iter().next().unwrap())
|
||||
|
||||
@@ -32,6 +32,7 @@ tokio = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
approx_eq = "0.1"
|
||||
common-function-macro = { path = "../common/function-macro" }
|
||||
num = "0.4"
|
||||
num-traits = "0.2"
|
||||
format_num = "0.1"
|
||||
|
||||
@@ -2,12 +2,12 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::{
|
||||
CatalogList, CatalogProvider, SchemaProvider, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME,
|
||||
};
|
||||
use common_function::scalars::aggregate::AggregateFunctionMeta;
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::CreateAccumulatorSnafu;
|
||||
use common_query::error::Result as QueryResult;
|
||||
use common_query::logical_plan::Accumulator;
|
||||
@@ -54,10 +54,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct MySumAccumulatorCreator {
|
||||
input_type: ArcSwapOption<Vec<ConcreteDataType>>,
|
||||
}
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
struct MySumAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for MySumAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
@@ -80,26 +79,6 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
|
||||
creator
|
||||
}
|
||||
|
||||
fn input_types(&self) -> QueryResult<Vec<ConcreteDataType>> {
|
||||
Ok(self.input_type
|
||||
.load()
|
||||
.as_ref()
|
||||
.expect("input_type is not present, check if DataFusion has changed its UDAF execution logic")
|
||||
.as_ref()
|
||||
.clone())
|
||||
}
|
||||
|
||||
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> QueryResult<()> {
|
||||
let old = self.input_type.swap(Some(Arc::new(input_types.clone())));
|
||||
if let Some(old) = old {
|
||||
assert_eq!(old.len(), input_types.len());
|
||||
old.iter().zip(input_types.iter()).for_each(|(x, y)|
|
||||
assert_eq!(x, y, "input type {:?} != {:?}, check if DataFusion has changed its UDAF execution logic", x, y)
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn output_type(&self) -> QueryResult<ConcreteDataType> {
|
||||
let input_type = &self.input_types()?[0];
|
||||
with_match_primitive_type_id!(
|
||||
|
||||
Reference in New Issue
Block a user