feat: procedure macro to help writing UDAF (#242)

* feat: procedure macro to help writing UDAF

* resolve code review comments

Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
LFC
2022-09-13 10:39:44 +08:00
committed by GitHub
parent 628cdb89e8
commit 64b6b2afe1
22 changed files with 226 additions and 317 deletions

View File

@@ -13,6 +13,7 @@ arc-swap = "1.0"
chrono-tz = "0.6"
common-error = { path = "../error" }
common-query = { path = "../query" }
common-function-macro = { path = "../function-macro" }
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2" }
datatypes = { path = "../../datatypes" }
libc = "0.2"

View File

@@ -1,10 +1,8 @@
use std::cmp::Ordering;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_query::error::{
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result,
};
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::vectors::ConstantVector;
@@ -98,10 +96,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ArgmaxAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ArgmaxAccumulatorCreator {}
impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -124,23 +121,6 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::uint64_datatype())
}

View File

@@ -1,10 +1,8 @@
use std::cmp::Ordering;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_query::error::{
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result,
};
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Result};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::vectors::ConstantVector;
@@ -107,10 +105,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ArgminAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ArgminAccumulatorCreator {}
impl AggregateFunctionCreator for ArgminAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -133,23 +130,6 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::uint32_datatype())
}

View File

@@ -1,10 +1,9 @@
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu,
Result,
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -118,10 +117,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct DiffAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct DiffAccumulatorCreator {}
impl AggregateFunctionCreator for DiffAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -144,24 +142,6 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() != input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);

View File

@@ -1,10 +1,9 @@
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, InvalidInputStateSnafu,
Result,
BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -125,10 +124,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct MeanAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct MeanAccumulatorCreator {}
impl AggregateFunctionCreator for MeanAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -151,23 +149,6 @@ impl AggregateFunctionCreator for MeanAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);

View File

@@ -2,10 +2,9 @@ use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu,
Result,
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -175,10 +174,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct MedianAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct MedianAccumulatorCreator {}
impl AggregateFunctionCreator for MedianAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -201,23 +199,6 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);

View File

@@ -64,55 +64,24 @@ pub(crate) struct AggregateFunctions;
impl AggregateFunctions {
pub fn register(registry: &FunctionRegistry) {
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"median",
1,
Arc::new(|| Arc::new(MedianAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"diff",
1,
Arc::new(|| Arc::new(DiffAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"mean",
1,
Arc::new(|| Arc::new(MeanAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"polyval",
2,
Arc::new(|| Arc::new(PolyvalAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"argmax",
1,
Arc::new(|| Arc::new(ArgmaxAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"argmin",
1,
Arc::new(|| Arc::new(ArgminAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"diff",
1,
Arc::new(|| Arc::new(DiffAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"percentile",
2,
Arc::new(|| Arc::new(PercentileAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"scipystatsnormcdf",
2,
Arc::new(|| Arc::new(ScipyStatsNormCdfAccumulatorCreator::default())),
)));
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"scipystatsnormpdf",
2,
Arc::new(|| Arc::new(ScipyStatsNormPdfAccumulatorCreator::default())),
)));
macro_rules! register_aggr_func {
($name :expr, $arg_count :expr, $creator :ty) => {
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
$name,
$arg_count,
Arc::new(|| Arc::new(<$creator>::default())),
)));
};
}
register_aggr_func!("median", 1, MedianAccumulatorCreator);
register_aggr_func!("diff", 1, DiffAccumulatorCreator);
register_aggr_func!("mean", 1, MeanAccumulatorCreator);
register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator);
register_aggr_func!("argmax", 1, ArgmaxAccumulatorCreator);
register_aggr_func!("argmin", 1, ArgminAccumulatorCreator);
register_aggr_func!("percentile", 2, PercentileAccumulatorCreator);
register_aggr_func!("scipystatsnormcdf", 2, ScipyStatsNormCdfAccumulatorCreator);
register_aggr_func!("scipystatsnormpdf", 2, ScipyStatsNormPdfAccumulatorCreator);
}
}

View File

@@ -2,10 +2,10 @@ use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result,
FromScalarValueSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -231,10 +231,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct PercentileAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct PercentileAccumulatorCreator {}
impl AggregateFunctionCreator for PercentileAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -257,23 +256,6 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);

View File

@@ -1,10 +1,10 @@
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result,
FromScalarValueSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -187,10 +187,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct PolyvalAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct PolyvalAccumulatorCreator {}
impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -213,23 +212,6 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);

View File

@@ -1,10 +1,9 @@
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu,
Result,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -174,10 +173,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ScipyStatsNormCdfAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ScipyStatsNormCdfAccumulatorCreator {}
impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);

View File

@@ -1,10 +1,9 @@
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu,
Result,
FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
@@ -174,10 +173,9 @@ where
}
}
#[derive(Debug, Default)]
pub struct ScipyStatsNormPdfAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct ScipyStatsNormPdfAccumulatorCreator {}
impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
@@ -200,23 +198,6 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
ensure!(input_types.is_some(), InvalidInputStateSnafu);
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
ensure!(old.len() == input_types.len(), InvalidInputStateSnafu);
for (x, y) in old.iter().zip(input_types.iter()) {
ensure!(x == y, InvalidInputStateSnafu);
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 2, InvalidInputStateSnafu);