diff --git a/src/common/function/src/aggr/hll.rs b/src/common/function/src/aggr/hll.rs index b4df0d77f8..2f37f1525b 100644 --- a/src/common/function/src/aggr/hll.rs +++ b/src/common/function/src/aggr/hll.rs @@ -12,6 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Two UDAFs are implemented for HyperLogLog: +//! +//! - `hll`: Accepts a string column and aggregates the values into a +//! HyperLogLog state. +//! - `hll_merge`: Accepts a binary column of states generated by `hll` +//! and merges them into a single state. +//! +//! The states can be then used to estimate the cardinality of the +//! values in the column by `hll_count` UDF. + use std::sync::Arc; use common_query::prelude::*; diff --git a/src/common/function/src/aggr/uddsketch_state.rs b/src/common/function/src/aggr/uddsketch_state.rs index e1eac765da..3ac138736d 100644 --- a/src/common/function/src/aggr/uddsketch_state.rs +++ b/src/common/function/src/aggr/uddsketch_state.rs @@ -12,6 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Implementation of the `uddsketch_state` UDAF that generate the state of +//! UDDSketch for a given set of values. +//! +//! The generated state can be used to compute approximate quantiles using +//! `uddsketch_calc` UDF. + use std::sync::Arc; use common_query::prelude::*; diff --git a/src/common/function/src/scalars/aggregate.rs b/src/common/function/src/scalars/aggregate.rs index 81eea378df..65c82ba99c 100644 --- a/src/common/function/src/scalars/aggregate.rs +++ b/src/common/function/src/scalars/aggregate.rs @@ -12,24 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod argmax; -mod argmin; -mod diff; -mod mean; -mod polyval; -mod scipy_stats_norm_cdf; -mod scipy_stats_norm_pdf; +//! # Deprecate Warning: +//! +//! This module is deprecated and will be removed in the future. +//! All UDAF implementation here are not maintained and should +//! not be used before they are refactored into the `src/aggr` +//! version. use std::sync::Arc; -pub use argmax::ArgmaxAccumulatorCreator; -pub use argmin::ArgminAccumulatorCreator; use common_query::logical_plan::AggregateFunctionCreatorRef; -pub use diff::DiffAccumulatorCreator; -pub use mean::MeanAccumulatorCreator; -pub use polyval::PolyvalAccumulatorCreator; -pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator; -pub use scipy_stats_norm_pdf::ScipyStatsNormPdfAccumulatorCreator; use crate::function_registry::FunctionRegistry; use crate::scalars::vector::product::VectorProductCreator; @@ -76,31 +68,22 @@ pub(crate) struct AggregateFunctions; impl AggregateFunctions { pub fn register(registry: &FunctionRegistry) { - macro_rules! register_aggr_func { - ($name :expr, $arg_count :expr, $creator :ty) => { - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - $name, - $arg_count, - Arc::new(|| Arc::new(<$creator>::default())), - ))); - }; - } - - register_aggr_func!("diff", 1, DiffAccumulatorCreator); - register_aggr_func!("mean", 1, MeanAccumulatorCreator); - register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator); - register_aggr_func!("argmax", 1, ArgmaxAccumulatorCreator); - register_aggr_func!("argmin", 1, ArgminAccumulatorCreator); - register_aggr_func!("scipystatsnormcdf", 2, ScipyStatsNormCdfAccumulatorCreator); - register_aggr_func!("scipystatsnormpdf", 2, ScipyStatsNormPdfAccumulatorCreator); - register_aggr_func!("vec_sum", 1, VectorSumCreator); - register_aggr_func!("vec_product", 1, VectorProductCreator); + registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( + "vec_sum", + 1, + Arc::new(|| Arc::new(VectorSumCreator::default())), + ))); + registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( + "vec_product", + 1, + Arc::new(|| Arc::new(VectorProductCreator::default())), + ))); #[cfg(feature = "geo")] - register_aggr_func!( + registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( "json_encode_path", 3, - super::geo::encoding::JsonPathEncodeFunctionCreator - ); + Arc::new(|| Arc::new(super::geo::encoding::JsonPathEncodeFunctionCreator::default())), + ))); } } diff --git a/src/common/function/src/scalars/aggregate/argmax.rs b/src/common/function/src/scalars/aggregate/argmax.rs deleted file mode 100644 index 4749ff9a3a..0000000000 --- a/src/common/function/src/scalars/aggregate/argmax.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::cmp::Ordering; -use std::sync::Arc; - -use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result, -}; -use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::types::{LogicalPrimitiveType, WrapperType}; -use datatypes::vectors::{ConstantVector, Helper}; -use datatypes::with_match_primitive_type_id; -use snafu::ensure; - -// https://numpy.org/doc/stable/reference/generated/numpy.argmax.html -// return the index of the max value -#[derive(Debug, Default)] -pub struct Argmax { - max: Option, - n: u64, -} - -impl Argmax -where - T: PartialOrd + Copy, -{ - fn update(&mut self, value: T, index: u64) { - if let Some(Ordering::Less) = self.max.partial_cmp(&Some(value)) { - self.max = Some(value); - self.n = index; - } - } -} - -impl Accumulator for Argmax -where - T: WrapperType + PartialOrd, -{ - fn state(&self) -> Result> { - match self.max { - Some(max) => Ok(vec![max.into(), self.n.into()]), - _ => Ok(vec![Value::Null, self.n.into()]), - } - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let column = &values[0]; - let column: &::VectorType = if column.is_const() { - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - for (i, v) in column.iter_data().enumerate() { - if let Some(value) = v { - self.update(value, i as u64); - } - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - ensure!( - states.len() == 2, - BadAccumulatorImplSnafu { - err_msg: "expect 2 states in `merge_batch`", - } - ); - - let max = &states[0]; - let index = &states[1]; - let max: &::VectorType = unsafe { Helper::static_cast(max) }; - let index: &::VectorType = unsafe { Helper::static_cast(index) }; - index - .iter_data() - .flatten() - .zip(max.iter_data().flatten()) - .for_each(|(i, max)| self.update(max, i)); - Ok(()) - } - - fn evaluate(&self) -> Result { - match self.max { - Some(_) => Ok(self.n.into()), - _ => Ok(Value::Null), - } - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct ArgmaxAccumulatorCreator {} - -impl AggregateFunctionCreator for ArgmaxAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(Argmax::<<$S as LogicalPrimitiveType>::Wrapper>::default())) - }, - { - let err_msg = format!( - "\"ARGMAX\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - Ok(ConcreteDataType::uint64_datatype()) - } - - fn state_types(&self) -> Result> { - let input_types = self.input_types()?; - - ensure!(input_types.len() == 1, InvalidInputStateSnafu); - - Ok(vec![ - input_types.into_iter().next().unwrap(), - ConcreteDataType::uint64_datatype(), - ]) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::Int32Vector; - - use super::*; - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut argmax = Argmax::::default(); - argmax.update_batch(&[]).unwrap(); - assert_eq!(Value::Null, argmax.evaluate().unwrap()); - - // test update one not-null value - let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; - argmax.update_batch(&v).unwrap(); - assert_eq!(Value::from(0_u64), argmax.evaluate().unwrap()); - - // test update one null value - let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; - argmax.update_batch(&v).unwrap(); - assert_eq!(Value::Null, argmax.evaluate().unwrap()); - - // test update no null-value batch - let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-1i32), - Some(1), - Some(3), - ]))]; - argmax.update_batch(&v).unwrap(); - assert_eq!(Value::from(2_u64), argmax.evaluate().unwrap()); - - // test update null-value batch - let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-2i32), - None, - Some(4), - ]))]; - argmax.update_batch(&v).unwrap(); - assert_eq!(Value::from(2_u64), argmax.evaluate().unwrap()); - - // test update with constant vector - let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(ConstantVector::new( - Arc::new(Int32Vector::from_vec(vec![4])), - 10, - ))]; - argmax.update_batch(&v).unwrap(); - assert_eq!(Value::from(0_u64), argmax.evaluate().unwrap()); - } -} diff --git a/src/common/function/src/scalars/aggregate/argmin.rs b/src/common/function/src/scalars/aggregate/argmin.rs deleted file mode 100644 index fe89184460..0000000000 --- a/src/common/function/src/scalars/aggregate/argmin.rs +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::cmp::Ordering; -use std::sync::Arc; - -use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - BadAccumulatorImplSnafu, CreateAccumulatorSnafu, InvalidInputStateSnafu, Result, -}; -use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::vectors::{ConstantVector, Helper}; -use datatypes::with_match_primitive_type_id; -use snafu::ensure; - -// // https://numpy.org/doc/stable/reference/generated/numpy.argmin.html -#[derive(Debug, Default)] -pub struct Argmin { - min: Option, - n: u32, -} - -impl Argmin -where - T: Copy + PartialOrd, -{ - fn update(&mut self, value: T, index: u32) { - match self.min { - Some(min) => { - if let Some(Ordering::Greater) = min.partial_cmp(&value) { - self.min = Some(value); - self.n = index; - } - } - None => { - self.min = Some(value); - self.n = index; - } - } - } -} - -impl Accumulator for Argmin -where - T: WrapperType + PartialOrd, -{ - fn state(&self) -> Result> { - match self.min { - Some(min) => Ok(vec![min.into(), self.n.into()]), - _ => Ok(vec![Value::Null, self.n.into()]), - } - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - ensure!(values.len() == 1, InvalidInputStateSnafu); - - let column = &values[0]; - let column: &::VectorType = if column.is_const() { - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - for (i, v) in column.iter_data().enumerate() { - if let Some(value) = v { - self.update(value, i as u32); - } - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - ensure!( - states.len() == 2, - BadAccumulatorImplSnafu { - err_msg: "expect 2 states in `merge_batch`", - } - ); - - let min = &states[0]; - let index = &states[1]; - let min: &::VectorType = unsafe { Helper::static_cast(min) }; - let index: &::VectorType = unsafe { Helper::static_cast(index) }; - index - .iter_data() - .flatten() - .zip(min.iter_data().flatten()) - .for_each(|(i, min)| self.update(min, i)); - Ok(()) - } - - fn evaluate(&self) -> Result { - match self.min { - Some(_) => Ok(self.n.into()), - _ => Ok(Value::Null), - } - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct ArgminAccumulatorCreator {} - -impl AggregateFunctionCreator for ArgminAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(Argmin::<<$S as LogicalPrimitiveType>::Wrapper>::default())) - }, - { - let err_msg = format!( - "\"ARGMIN\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - Ok(ConcreteDataType::uint32_datatype()) - } - - fn state_types(&self) -> Result> { - let input_types = self.input_types()?; - - ensure!(input_types.len() == 1, InvalidInputStateSnafu); - - Ok(vec![ - input_types.into_iter().next().unwrap(), - ConcreteDataType::uint32_datatype(), - ]) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::Int32Vector; - - use super::*; - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut argmin = Argmin::::default(); - argmin.update_batch(&[]).unwrap(); - assert_eq!(Value::Null, argmin.evaluate().unwrap()); - - // test update one not-null value - let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; - argmin.update_batch(&v).unwrap(); - assert_eq!(Value::from(0_u32), argmin.evaluate().unwrap()); - - // test update one null value - let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; - argmin.update_batch(&v).unwrap(); - assert_eq!(Value::Null, argmin.evaluate().unwrap()); - - // test update no null-value batch - let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-1i32), - Some(1), - Some(3), - ]))]; - argmin.update_batch(&v).unwrap(); - assert_eq!(Value::from(0_u32), argmin.evaluate().unwrap()); - - // test update null-value batch - let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-2i32), - None, - Some(4), - ]))]; - argmin.update_batch(&v).unwrap(); - assert_eq!(Value::from(0_u32), argmin.evaluate().unwrap()); - - // test update with constant vector - let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(ConstantVector::new( - Arc::new(Int32Vector::from_vec(vec![4])), - 10, - ))]; - argmin.update_batch(&v).unwrap(); - assert_eq!(Value::from(0_u32), argmin.evaluate().unwrap()); - } -} diff --git a/src/common/function/src/scalars/aggregate/diff.rs b/src/common/function/src/scalars/aggregate/diff.rs deleted file mode 100644 index 25d1614e4b..0000000000 --- a/src/common/function/src/scalars/aggregate/diff.rs +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::marker::PhantomData; -use std::sync::Arc; - -use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, InvalidInputStateSnafu, - Result, -}; -use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::value::ListValue; -use datatypes::vectors::{ConstantVector, Helper, ListVector}; -use datatypes::with_match_primitive_type_id; -use num_traits::AsPrimitive; -use snafu::{ensure, OptionExt, ResultExt}; - -// https://numpy.org/doc/stable/reference/generated/numpy.diff.html -// I is the input type, O is the output type. -#[derive(Debug, Default)] -pub struct Diff { - values: Vec, - _phantom: PhantomData, -} - -impl Diff { - fn push(&mut self, value: I) { - self.values.push(value); - } -} - -impl Accumulator for Diff -where - I: WrapperType, - O: WrapperType, - I::Native: AsPrimitive, - O::Native: std::ops::Sub, -{ - fn state(&self) -> Result> { - let nums = self - .values - .iter() - .map(|&n| n.into()) - .collect::>(); - Ok(vec![Value::List(ListValue::new( - nums, - I::LogicalType::build_data_type(), - ))]) - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - ensure!(values.len() == 1, InvalidInputStateSnafu); - - let column = &values[0]; - let mut len = 1; - let column: &::VectorType = if column.is_const() { - len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - (0..len).for_each(|_| { - for v in column.iter_data().flatten() { - self.push(v); - } - }); - Ok(()) - } - - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - let states = &states[0]; - let states = states - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect ListVector, got vector type {}", - states.vector_type_name() - ), - })?; - for state in states.values_iter() { - if let Some(state) = state.context(FromScalarValueSnafu)? { - self.update_batch(&[state])?; - } - } - Ok(()) - } - - fn evaluate(&self) -> Result { - if self.values.is_empty() || self.values.len() == 1 { - return Ok(Value::Null); - } - let diff = self - .values - .windows(2) - .map(|x| { - let native = x[1].into_native().as_() - x[0].into_native().as_(); - O::from_native(native).into() - }) - .collect::>(); - let diff = Value::List(ListValue::new(diff, O::LogicalType::build_data_type())); - Ok(diff) - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct DiffAccumulatorCreator {} - -impl AggregateFunctionCreator for DiffAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(Diff::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default())) - }, - { - let err_msg = format!( - "\"DIFF\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - let input_types = self.input_types()?; - ensure!(input_types.len() == 1, InvalidInputStateSnafu); - with_match_primitive_type_id!( - input_types[0].logical_type_id(), - |$S| { - Ok(ConcreteDataType::list_datatype($S::default().into())) - }, - { - unreachable!() - } - ) - } - - fn state_types(&self) -> Result> { - let input_types = self.input_types()?; - ensure!(input_types.len() == 1, InvalidInputStateSnafu); - with_match_primitive_type_id!( - input_types[0].logical_type_id(), - |$S| { - Ok(vec![ConcreteDataType::list_datatype($S::default().into())]) - }, - { - unreachable!() - } - ) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::Int32Vector; - - use super::*; - - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut diff = Diff::::default(); - diff.update_batch(&[]).unwrap(); - assert!(diff.values.is_empty()); - assert_eq!(Value::Null, diff.evaluate().unwrap()); - - // test update one not-null value - let mut diff = Diff::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; - diff.update_batch(&v).unwrap(); - assert_eq!(Value::Null, diff.evaluate().unwrap()); - - // test update one null value - let mut diff = Diff::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; - diff.update_batch(&v).unwrap(); - assert_eq!(Value::Null, diff.evaluate().unwrap()); - - // test update no null-value batch - let mut diff = Diff::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ]))]; - let values = vec![Value::from(2_i64), Value::from(1_i64)]; - diff.update_batch(&v).unwrap(); - assert_eq!( - Value::List(ListValue::new(values, ConcreteDataType::int64_datatype())), - diff.evaluate().unwrap() - ); - - // test update null-value batch - let mut diff = Diff::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-2i32), - None, - Some(3), - Some(4), - ]))]; - let values = vec![Value::from(5_i64), Value::from(1_i64)]; - diff.update_batch(&v).unwrap(); - assert_eq!( - Value::List(ListValue::new(values, ConcreteDataType::int64_datatype())), - diff.evaluate().unwrap() - ); - - // test update with constant vector - let mut diff = Diff::::default(); - let v: Vec = vec![Arc::new(ConstantVector::new( - Arc::new(Int32Vector::from_vec(vec![4])), - 4, - ))]; - let values = vec![Value::from(0_i64), Value::from(0_i64), Value::from(0_i64)]; - diff.update_batch(&v).unwrap(); - assert_eq!( - Value::List(ListValue::new(values, ConcreteDataType::int64_datatype())), - diff.evaluate().unwrap() - ); - } -} diff --git a/src/common/function/src/scalars/aggregate/mean.rs b/src/common/function/src/scalars/aggregate/mean.rs deleted file mode 100644 index ed66c90bdb..0000000000 --- a/src/common/function/src/scalars/aggregate/mean.rs +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::marker::PhantomData; -use std::sync::Arc; - -use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, InvalidInputStateSnafu, - Result, -}; -use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::types::WrapperType; -use datatypes::vectors::{ConstantVector, Float64Vector, Helper, UInt64Vector}; -use datatypes::with_match_primitive_type_id; -use num_traits::AsPrimitive; -use snafu::{ensure, OptionExt}; - -#[derive(Debug, Default)] -pub struct Mean { - sum: f64, - n: u64, - _phantom: PhantomData, -} - -impl Mean -where - T: WrapperType, - T::Native: AsPrimitive, -{ - #[inline(always)] - fn push(&mut self, value: T) { - self.sum += value.into_native().as_(); - self.n += 1; - } - - #[inline(always)] - fn update(&mut self, sum: f64, n: u64) { - self.sum += sum; - self.n += n; - } -} - -impl Accumulator for Mean -where - T: WrapperType, - T::Native: AsPrimitive, -{ - fn state(&self) -> Result> { - Ok(vec![self.sum.into(), self.n.into()]) - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - ensure!(values.len() == 1, InvalidInputStateSnafu); - let column = &values[0]; - let mut len = 1; - let column: &::VectorType = if column.is_const() { - len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - (0..len).for_each(|_| { - for v in column.iter_data().flatten() { - self.push(v); - } - }); - - Ok(()) - } - - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - ensure!( - states.len() == 2, - BadAccumulatorImplSnafu { - err_msg: "expect 2 states in `merge_batch`", - } - ); - - let sum = &states[0]; - let n = &states[1]; - - let sum = sum - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect Float64Vector, got vector type {}", - sum.vector_type_name() - ), - })?; - - let n = n - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect UInt64Vector, got vector type {}", - sum.vector_type_name() - ), - })?; - - sum.iter_data().zip(n.iter_data()).for_each(|(sum, n)| { - if let (Some(sum), Some(n)) = (sum, n) { - self.update(sum, n); - } - }); - Ok(()) - } - - fn evaluate(&self) -> Result { - if self.n == 0 { - return Ok(Value::Null); - } - let values = self.sum / self.n as f64; - Ok(values.into()) - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct MeanAccumulatorCreator {} - -impl AggregateFunctionCreator for MeanAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(Mean::<<$S as LogicalPrimitiveType>::Native>::default())) - }, - { - let err_msg = format!( - "\"MEAN\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - let input_types = self.input_types()?; - ensure!(input_types.len() == 1, InvalidInputStateSnafu); - Ok(ConcreteDataType::float64_datatype()) - } - - fn state_types(&self) -> Result> { - let input_types = self.input_types()?; - ensure!(input_types.len() == 1, InvalidInputStateSnafu); - Ok(vec![ - ConcreteDataType::float64_datatype(), - ConcreteDataType::uint64_datatype(), - ]) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::Int32Vector; - - use super::*; - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut mean = Mean::::default(); - mean.update_batch(&[]).unwrap(); - assert_eq!(Value::Null, mean.evaluate().unwrap()); - - // test update one not-null value - let mut mean = Mean::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; - mean.update_batch(&v).unwrap(); - assert_eq!(Value::from(42.0_f64), mean.evaluate().unwrap()); - - // test update one null value - let mut mean = Mean::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; - mean.update_batch(&v).unwrap(); - assert_eq!(Value::Null, mean.evaluate().unwrap()); - - // test update no null-value batch - let mut mean = Mean::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ]))]; - mean.update_batch(&v).unwrap(); - assert_eq!(Value::from(0.6666666666666666), mean.evaluate().unwrap()); - - // test update null-value batch - let mut mean = Mean::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-2i32), - None, - Some(3), - Some(4), - ]))]; - mean.update_batch(&v).unwrap(); - assert_eq!(Value::from(1.6666666666666667), mean.evaluate().unwrap()); - - // test update with constant vector - let mut mean = Mean::::default(); - let v: Vec = vec![Arc::new(ConstantVector::new( - Arc::new(Int32Vector::from_vec(vec![4])), - 10, - ))]; - mean.update_batch(&v).unwrap(); - assert_eq!(Value::from(4.0), mean.evaluate().unwrap()); - } -} diff --git a/src/common/function/src/scalars/aggregate/polyval.rs b/src/common/function/src/scalars/aggregate/polyval.rs deleted file mode 100644 index bc3986fd0e..0000000000 --- a/src/common/function/src/scalars/aggregate/polyval.rs +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::marker::PhantomData; -use std::sync::Arc; - -use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, - FromScalarValueSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, Result, -}; -use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::types::{LogicalPrimitiveType, WrapperType}; -use datatypes::value::ListValue; -use datatypes::vectors::{ConstantVector, Helper, Int64Vector, ListVector}; -use datatypes::with_match_primitive_type_id; -use num_traits::AsPrimitive; -use snafu::{ensure, OptionExt, ResultExt}; - -// https://numpy.org/doc/stable/reference/generated/numpy.polyval.html -#[derive(Debug, Default)] -pub struct Polyval -where - T: WrapperType, - T::Native: AsPrimitive, - PolyT: WrapperType, - PolyT::Native: std::ops::Mul, -{ - values: Vec, - // DataFusion casts constant in into i64 type. - x: Option, - _phantom: PhantomData, -} - -impl Polyval -where - T: WrapperType, - T::Native: AsPrimitive, - PolyT: WrapperType, - PolyT::Native: std::ops::Mul, -{ - fn push(&mut self, value: T) { - self.values.push(value); - } -} - -impl Accumulator for Polyval -where - T: WrapperType, - T::Native: AsPrimitive, - PolyT: WrapperType + std::iter::Sum<::Native>, - PolyT::Native: std::ops::Mul + std::iter::Sum, - i64: AsPrimitive<::Native>, -{ - fn state(&self) -> Result> { - let nums = self - .values - .iter() - .map(|&n| n.into()) - .collect::>(); - Ok(vec![ - Value::List(ListValue::new(nums, T::LogicalType::build_data_type())), - self.x.into(), - ]) - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - ensure!(values.len() == 2, InvalidInputStateSnafu); - ensure!(values[0].len() == values[1].len(), InvalidInputStateSnafu); - if values[0].len() == 0 { - return Ok(()); - } - // This is a unary accumulator, so only one column is provided. - let column = &values[0]; - let mut len = 1; - let column: &::VectorType = if column.is_const() { - len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - (0..len).for_each(|_| { - for v in column.iter_data().flatten() { - self.push(v); - } - }); - - let x = &values[1]; - let x = Helper::check_get_scalar::(x).context(error::InvalidInputTypeSnafu { - err_msg: "expecting \"POLYVAL\" function's second argument to be a positive integer", - })?; - // `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0` - let first = x.get(0); - ensure!(!first.is_null(), InvalidInputColSnafu); - - for i in 1..x.len() { - ensure!(first == x.get(i), InvalidInputColSnafu); - } - - let first = match first { - Value::Int64(v) => v, - // unreachable because we have checked `first` is not null and is i64 above - _ => unreachable!(), - }; - if let Some(x) = self.x { - ensure!(x == first, InvalidInputColSnafu); - } else { - self.x = Some(first); - }; - Ok(()) - } - - // DataFusion executes accumulators in partitions. In some execution stage, DataFusion will - // merge states from other accumulators (returned by `state()` method). - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - ensure!( - states.len() == 2, - BadAccumulatorImplSnafu { - err_msg: "expect 2 states in `merge_batch`", - } - ); - - let x = &states[1]; - let x = x - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect Int64Vector, got vector type {}", - x.vector_type_name() - ), - })?; - let x = x.get(0); - if x.is_null() { - return Ok(()); - } - let x = match x { - Value::Int64(x) => x, - _ => unreachable!(), - }; - self.x = Some(x); - - let values = &states[0]; - let values = values - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect ListVector, got vector type {}", - values.vector_type_name() - ), - })?; - for value in values.values_iter() { - if let Some(value) = value.context(FromScalarValueSnafu)? { - let column: &::VectorType = unsafe { Helper::static_cast(&value) }; - for v in column.iter_data().flatten() { - self.push(v); - } - } - } - - Ok(()) - } - - // DataFusion expects this function to return the final value of this aggregator. - fn evaluate(&self) -> Result { - if self.values.is_empty() { - return Ok(Value::Null); - } - let x = if let Some(x) = self.x { - x - } else { - return Ok(Value::Null); - }; - let len = self.values.len(); - let polyval: PolyT = self - .values - .iter() - .enumerate() - .map(|(i, &value)| value.into_native().as_() * x.pow((len - 1 - i) as u32).as_()) - .sum(); - Ok(polyval.into()) - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct PolyvalAccumulatorCreator {} - -impl AggregateFunctionCreator for PolyvalAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(Polyval::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default())) - }, - { - let err_msg = format!( - "\"POLYVAL\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - let input_types = self.input_types()?; - ensure!(input_types.len() == 2, InvalidInputStateSnafu); - let input_type = self.input_types()?[0].logical_type_id(); - with_match_primitive_type_id!( - input_type, - |$S| { - Ok(<<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::build_data_type()) - }, - { - unreachable!() - } - ) - } - - fn state_types(&self) -> Result> { - let input_types = self.input_types()?; - ensure!(input_types.len() == 2, InvalidInputStateSnafu); - Ok(vec![ - ConcreteDataType::list_datatype(input_types.into_iter().next().unwrap()), - ConcreteDataType::int64_datatype(), - ]) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::Int32Vector; - - use super::*; - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut polyval = Polyval::::default(); - polyval.update_batch(&[]).unwrap(); - assert!(polyval.values.is_empty()); - assert_eq!(Value::Null, polyval.evaluate().unwrap()); - - // test update one not-null value - let mut polyval = Polyval::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Some(3)])), - Arc::new(Int64Vector::from(vec![Some(2_i64)])), - ]; - polyval.update_batch(&v).unwrap(); - assert_eq!(Value::Int64(3), polyval.evaluate().unwrap()); - - // test update one null value - let mut polyval = Polyval::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Option::::None])), - Arc::new(Int64Vector::from(vec![Some(2_i64)])), - ]; - polyval.update_batch(&v).unwrap(); - assert_eq!(Value::Null, polyval.evaluate().unwrap()); - - // test update no null-value batch - let mut polyval = Polyval::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Some(3), Some(0), Some(1)])), - Arc::new(Int64Vector::from(vec![ - Some(2_i64), - Some(2_i64), - Some(2_i64), - ])), - ]; - polyval.update_batch(&v).unwrap(); - assert_eq!(Value::Int64(13), polyval.evaluate().unwrap()); - - // test update null-value batch - let mut polyval = Polyval::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Some(3), Some(0), None, Some(1)])), - Arc::new(Int64Vector::from(vec![ - Some(2_i64), - Some(2_i64), - Some(2_i64), - Some(2_i64), - ])), - ]; - polyval.update_batch(&v).unwrap(); - assert_eq!(Value::Int64(13), polyval.evaluate().unwrap()); - - // test update with constant vector - let mut polyval = Polyval::::default(); - let v: Vec = vec![ - Arc::new(ConstantVector::new( - Arc::new(Int32Vector::from_vec(vec![4])), - 2, - )), - Arc::new(Int64Vector::from(vec![Some(5_i64), Some(5_i64)])), - ]; - polyval.update_batch(&v).unwrap(); - assert_eq!(Value::Int64(24), polyval.evaluate().unwrap()); - } -} diff --git a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs b/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs deleted file mode 100644 index 09a9c820d8..0000000000 --- a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, - FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, - Result, -}; -use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::value::{ListValue, OrderedFloat}; -use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector}; -use datatypes::with_match_primitive_type_id; -use num_traits::AsPrimitive; -use snafu::{ensure, OptionExt, ResultExt}; -use statrs::distribution::{ContinuousCDF, Normal}; -use statrs::statistics::Statistics; - -// https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html - -#[derive(Debug, Default)] -pub struct ScipyStatsNormCdf { - values: Vec, - x: Option, -} - -impl ScipyStatsNormCdf { - fn push(&mut self, value: T) { - self.values.push(value); - } -} - -impl Accumulator for ScipyStatsNormCdf -where - T: WrapperType + std::iter::Sum, - T::Native: AsPrimitive, -{ - fn state(&self) -> Result> { - let nums = self - .values - .iter() - .map(|&x| x.into()) - .collect::>(); - Ok(vec![ - Value::List(ListValue::new(nums, T::LogicalType::build_data_type())), - self.x.into(), - ]) - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - ensure!(values.len() == 2, InvalidInputStateSnafu); - ensure!(values[1].len() == values[0].len(), InvalidInputStateSnafu); - - if values[0].len() == 0 { - return Ok(()); - } - - let column = &values[0]; - let mut len = 1; - let column: &::VectorType = if column.is_const() { - len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - - let x = &values[1]; - let x = Helper::check_get_scalar::(x).context(error::InvalidInputTypeSnafu { - err_msg: "expecting \"SCIPYSTATSNORMCDF\" function's second argument to be a positive integer", - })?; - let first = x.get(0); - ensure!(!first.is_null(), InvalidInputColSnafu); - let first = match first { - Value::Float64(OrderedFloat(v)) => v, - // unreachable because we have checked `first` is not null and is i64 above - _ => unreachable!(), - }; - if let Some(x) = self.x { - ensure!(x == first, InvalidInputColSnafu); - } else { - self.x = Some(first); - }; - - (0..len).for_each(|_| { - for v in column.iter_data().flatten() { - self.push(v); - } - }); - Ok(()) - } - - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - ensure!( - states.len() == 2, - BadAccumulatorImplSnafu { - err_msg: "expect 2 states in `merge_batch`", - } - ); - - let x = &states[1]; - let x = x - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect Float64Vector, got vector type {}", - x.vector_type_name() - ), - })?; - let x = x.get(0); - if x.is_null() { - return Ok(()); - } - let x = match x { - Value::Float64(OrderedFloat(x)) => x, - _ => unreachable!(), - }; - self.x = Some(x); - - let values = &states[0]; - let values = values - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect ListVector, got vector type {}", - values.vector_type_name() - ), - })?; - for value in values.values_iter() { - if let Some(value) = value.context(FromScalarValueSnafu)? { - let column: &::VectorType = unsafe { Helper::static_cast(&value) }; - for v in column.iter_data().flatten() { - self.push(v); - } - } - } - Ok(()) - } - - fn evaluate(&self) -> Result { - let mean = self.values.iter().map(|v| v.into_native().as_()).mean(); - let std_dev = self.values.iter().map(|v| v.into_native().as_()).std_dev(); - if mean.is_nan() || std_dev.is_nan() { - Ok(Value::Null) - } else { - let x = if let Some(x) = self.x { - x - } else { - return Ok(Value::Null); - }; - let n = Normal::new(mean, std_dev).context(GenerateFunctionSnafu)?; - Ok(n.cdf(x).into()) - } - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct ScipyStatsNormCdfAccumulatorCreator {} - -impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(ScipyStatsNormCdf::<<$S as LogicalPrimitiveType>::Wrapper>::default())) - }, - { - let err_msg = format!( - "\"SCIPYSTATSNORMCDF\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - let input_types = self.input_types()?; - ensure!(input_types.len() == 2, InvalidInputStateSnafu); - Ok(ConcreteDataType::float64_datatype()) - } - - fn state_types(&self) -> Result> { - let input_types = self.input_types()?; - ensure!(input_types.len() == 2, InvalidInputStateSnafu); - Ok(vec![ - ConcreteDataType::list_datatype(input_types[0].clone()), - ConcreteDataType::float64_datatype(), - ]) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::{Float64Vector, Int32Vector}; - - use super::*; - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::::default(); - scipy_stats_norm_cdf.update_batch(&[]).unwrap(); - assert!(scipy_stats_norm_cdf.values.is_empty()); - assert_eq!(Value::Null, scipy_stats_norm_cdf.evaluate().unwrap()); - - // test update no null-value batch - let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), - Arc::new(Float64Vector::from(vec![ - Some(2.0_f64), - Some(2.0_f64), - Some(2.0_f64), - ])), - ]; - scipy_stats_norm_cdf.update_batch(&v).unwrap(); - assert_eq!( - Value::from(0.8086334555398362), - scipy_stats_norm_cdf.evaluate().unwrap() - ); - - // test update null-value batch - let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])), - Arc::new(Float64Vector::from(vec![ - Some(2.0_f64), - None, - Some(2.0_f64), - Some(2.0_f64), - ])), - ]; - scipy_stats_norm_cdf.update_batch(&v).unwrap(); - assert_eq!( - Value::from(0.5412943699039795), - scipy_stats_norm_cdf.evaluate().unwrap() - ); - } -} diff --git a/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs b/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs deleted file mode 100644 index 2d5025ea3a..0000000000 --- a/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu, - FromScalarValueSnafu, GenerateFunctionSnafu, InvalidInputColSnafu, InvalidInputStateSnafu, - Result, -}; -use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::value::{ListValue, OrderedFloat}; -use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector}; -use datatypes::with_match_primitive_type_id; -use num_traits::AsPrimitive; -use snafu::{ensure, OptionExt, ResultExt}; -use statrs::distribution::{Continuous, Normal}; -use statrs::statistics::Statistics; - -// https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html - -#[derive(Debug, Default)] -pub struct ScipyStatsNormPdf { - values: Vec, - x: Option, -} - -impl ScipyStatsNormPdf { - fn push(&mut self, value: T) { - self.values.push(value); - } -} - -impl Accumulator for ScipyStatsNormPdf -where - T: WrapperType, - T::Native: AsPrimitive + std::iter::Sum, -{ - fn state(&self) -> Result> { - let nums = self - .values - .iter() - .map(|&x| x.into()) - .collect::>(); - Ok(vec![ - Value::List(ListValue::new(nums, T::LogicalType::build_data_type())), - self.x.into(), - ]) - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - ensure!(values.len() == 2, InvalidInputStateSnafu); - ensure!(values[1].len() == values[0].len(), InvalidInputStateSnafu); - - if values[0].len() == 0 { - return Ok(()); - } - - let column = &values[0]; - let mut len = 1; - let column: &::VectorType = if column.is_const() { - len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - - let x = &values[1]; - let x = Helper::check_get_scalar::(x).context(error::InvalidInputTypeSnafu { - err_msg: "expecting \"SCIPYSTATSNORMPDF\" function's second argument to be a positive integer", - })?; - let first = x.get(0); - ensure!(!first.is_null(), InvalidInputColSnafu); - let first = match first { - Value::Float64(OrderedFloat(v)) => v, - // unreachable because we have checked `first` is not null and is i64 above - _ => unreachable!(), - }; - if let Some(x) = self.x { - ensure!(x == first, InvalidInputColSnafu); - } else { - self.x = Some(first); - }; - - (0..len).for_each(|_| { - for v in column.iter_data().flatten() { - self.push(v); - } - }); - Ok(()) - } - - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - ensure!( - states.len() == 2, - BadAccumulatorImplSnafu { - err_msg: "expect 2 states in `merge_batch`", - } - ); - - let x = &states[1]; - let x = x - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect Float64Vector, got vector type {}", - x.vector_type_name() - ), - })?; - let x = x.get(0); - if x.is_null() { - return Ok(()); - } - let x = match x { - Value::Float64(OrderedFloat(x)) => x, - _ => unreachable!(), - }; - self.x = Some(x); - - let values = &states[0]; - let values = values - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect ListVector, got vector type {}", - values.vector_type_name() - ), - })?; - for value in values.values_iter() { - if let Some(value) = value.context(FromScalarValueSnafu)? { - let column: &::VectorType = unsafe { Helper::static_cast(&value) }; - for v in column.iter_data().flatten() { - self.push(v); - } - } - } - Ok(()) - } - - fn evaluate(&self) -> Result { - let mean = self.values.iter().map(|v| v.into_native().as_()).mean(); - let std_dev = self.values.iter().map(|v| v.into_native().as_()).std_dev(); - - if mean.is_nan() || std_dev.is_nan() { - Ok(Value::Null) - } else { - let x = if let Some(x) = self.x { - x - } else { - return Ok(Value::Null); - }; - let n = Normal::new(mean, std_dev).context(GenerateFunctionSnafu)?; - Ok(n.pdf(x).into()) - } - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct ScipyStatsNormPdfAccumulatorCreator {} - -impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(ScipyStatsNormPdf::<<$S as LogicalPrimitiveType>::Wrapper>::default())) - }, - { - let err_msg = format!( - "\"SCIPYSTATSNORMpdf\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - let input_types = self.input_types()?; - ensure!(input_types.len() == 2, InvalidInputStateSnafu); - Ok(ConcreteDataType::float64_datatype()) - } - - fn state_types(&self) -> Result> { - let input_types = self.input_types()?; - ensure!(input_types.len() == 2, InvalidInputStateSnafu); - Ok(vec![ - ConcreteDataType::list_datatype(input_types[0].clone()), - ConcreteDataType::float64_datatype(), - ]) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::{Float64Vector, Int32Vector}; - - use super::*; - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::::default(); - scipy_stats_norm_pdf.update_batch(&[]).unwrap(); - assert!(scipy_stats_norm_pdf.values.is_empty()); - assert_eq!(Value::Null, scipy_stats_norm_pdf.evaluate().unwrap()); - - // test update no null-value batch - let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), - Arc::new(Float64Vector::from(vec![ - Some(2.0_f64), - Some(2.0_f64), - Some(2.0_f64), - ])), - ]; - scipy_stats_norm_pdf.update_batch(&v).unwrap(); - assert_eq!( - Value::from(0.17843340219081558), - scipy_stats_norm_pdf.evaluate().unwrap() - ); - - // test update null-value batch - let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::::default(); - let v: Vec = vec![ - Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])), - Arc::new(Float64Vector::from(vec![ - Some(2.0_f64), - None, - Some(2.0_f64), - Some(2.0_f64), - ])), - ]; - scipy_stats_norm_pdf.update_batch(&v).unwrap(); - assert_eq!( - Value::from(0.12343972049858312), - scipy_stats_norm_pdf.evaluate().unwrap() - ); - } -} diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index b81d4cde8b..408bbab95d 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -24,7 +24,6 @@ use datatypes::arrow::datatypes::DataType as ArrowDatatype; use datatypes::error::Error as DataTypeError; use datatypes::prelude::ConcreteDataType; use snafu::{Location, Snafu}; -use statrs::StatsError; #[derive(Snafu)] #[snafu(visibility(pub))] @@ -38,14 +37,6 @@ pub enum Error { location: Location, }, - #[snafu(display("Failed to generate function"))] - GenerateFunction { - #[snafu(source)] - error: StatsError, - #[snafu(implicit)] - location: Location, - }, - #[snafu(display("Failed to cast scalar value into vector"))] FromScalarValue { #[snafu(implicit)] @@ -97,12 +88,6 @@ pub enum Error { location: Location, }, - #[snafu(display("unexpected: not constant column"))] - InvalidInputCol { - #[snafu(implicit)] - location: Location, - }, - #[snafu(display("General DataFusion error"))] GeneralDataFusion { #[snafu(source)] @@ -248,8 +233,6 @@ impl ErrorExt for Error { Error::CreateAccumulator { .. } | Error::DowncastVector { .. } | Error::InvalidInputState { .. } - | Error::InvalidInputCol { .. } - | Error::GenerateFunction { .. } | Error::BadAccumulatorImpl { .. } | Error::ToScalarValue { .. } | Error::GetScalarVector { .. } diff --git a/src/query/src/tests.rs b/src/query/src/tests.rs index cbce67a4fe..f2f2e40bf3 100644 --- a/src/query/src/tests.rs +++ b/src/query/src/tests.rs @@ -21,14 +21,8 @@ use table::TableRef; use crate::parser::QueryLanguageParser; use crate::{QueryEngineFactory, QueryEngineRef}; -mod argmax_test; -mod argmin_test; -mod mean_test; mod my_sum_udaf_example; -mod polyval_test; mod query_engine_test; -mod scipy_stats_norm_cdf_test; -mod scipy_stats_norm_pdf; mod time_range_filter_test; mod function; diff --git a/src/query/src/tests/argmax_test.rs b/src/query/src/tests/argmax_test.rs deleted file mode 100644 index 9f9c86e8e7..0000000000 --- a/src/query/src/tests/argmax_test.rs +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::WrapperType; - -use crate::error::Result; -use crate::tests::{exec_selection, function}; -use crate::QueryEngine; - -#[tokio::test] -async fn test_argmax_aggregator() -> Result<()> { - common_telemetry::init_default_ut_logging(); - let engine = function::create_query_engine(); - - macro_rules! test_argmax { - ([], $( { $T:ty } ),*) => { - $( - let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_argmax_success::<$T>(&column_name, "numbers", engine.clone()).await?; - )* - } - } - for_all_primitive_types! { test_argmax } - Ok(()) -} - -async fn test_argmax_success( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: WrapperType + PartialOrd, -{ - let sql = format!("select ARGMAX({column_name}) as argmax from {table_name}"); - let result = exec_selection(engine.clone(), &sql).await; - let value = function::get_value_from_batches("argmax", result); - - let numbers = - function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; - let expected_value = match numbers.len() { - 0 => 0_u64, - _ => { - let mut index = 0; - let mut max = numbers[0]; - for (i, &number) in numbers.iter().enumerate() { - if max < number { - max = number; - index = i; - } - } - index as u64 - } - }; - let expected_value = Value::from(expected_value); - assert_eq!(value, expected_value); - Ok(()) -} diff --git a/src/query/src/tests/argmin_test.rs b/src/query/src/tests/argmin_test.rs deleted file mode 100644 index 5baa532cc6..0000000000 --- a/src/query/src/tests/argmin_test.rs +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::WrapperType; - -use crate::error::Result; -use crate::tests::{exec_selection, function}; -use crate::QueryEngine; - -#[tokio::test] -async fn test_argmin_aggregator() -> Result<()> { - common_telemetry::init_default_ut_logging(); - let engine = function::create_query_engine(); - - macro_rules! test_argmin { - ([], $( { $T:ty } ),*) => { - $( - let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_argmin_success::<$T>(&column_name, "numbers", engine.clone()).await?; - )* - } - } - for_all_primitive_types! { test_argmin } - Ok(()) -} - -async fn test_argmin_success( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: WrapperType + PartialOrd, -{ - let sql = format!("select argmin({column_name}) as argmin from {table_name}"); - let result = exec_selection(engine.clone(), &sql).await; - let value = function::get_value_from_batches("argmin", result); - - let numbers = - function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; - let expected_value = match numbers.len() { - 0 => 0_u32, - _ => { - let mut index = 0; - let mut min = numbers[0]; - for (i, &number) in numbers.iter().enumerate() { - if min > number { - min = number; - index = i; - } - } - index as u32 - } - }; - let expected_value = Value::from(expected_value); - assert_eq!(value, expected_value); - Ok(()) -} diff --git a/src/query/src/tests/function.rs b/src/query/src/tests/function.rs index 49ed1b8850..9a5071f199 100644 --- a/src/query/src/tests/function.rs +++ b/src/query/src/tests/function.rs @@ -16,42 +16,14 @@ use std::sync::Arc; use common_function::scalars::vector::impl_conv::veclit_to_binlit; use common_recordbatch::RecordBatch; -use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::WrapperType; -use datatypes::vectors::{BinaryVector, Helper}; +use datatypes::vectors::BinaryVector; use rand::Rng; use table::test_util::MemTable; -use crate::tests::{exec_selection, new_query_engine_with_table}; -use crate::{QueryEngine, QueryEngineRef}; - -pub fn create_query_engine() -> QueryEngineRef { - let mut column_schemas = vec![]; - let mut columns = vec![]; - macro_rules! create_number_table { - ([], $( { $T:ty } ),*) => { - $( - let mut rng = rand::thread_rng(); - - let column_name = format!("{}_number", std::any::type_name::<$T>()); - let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true); - column_schemas.push(column_schema); - - let numbers = (1..=10).map(|_| rng.gen::<$T>()).collect::>(); - let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec())); - columns.push(column); - )* - } - } - for_all_primitive_types! { create_number_table } - - let schema = Arc::new(Schema::new(column_schemas.clone())); - let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let number_table = MemTable::table("numbers", recordbatch); - new_query_engine_with_table(number_table) -} +use crate::tests::new_query_engine_with_table; +use crate::QueryEngineRef; pub fn create_query_engine_for_vector10x3() -> QueryEngineRef { let mut column_schemas = vec![]; @@ -81,22 +53,6 @@ pub fn create_query_engine_for_vector10x3() -> QueryEngineRef { new_query_engine_with_table(vector_table) } -pub async fn get_numbers_from_table<'s, T>( - column_name: &'s str, - table_name: &'s str, - engine: Arc, -) -> Vec -where - T: WrapperType, -{ - let sql = format!("SELECT {column_name} FROM {table_name}"); - let numbers = exec_selection(engine, &sql).await; - - let column = numbers[0].column(0); - let column: &::VectorType = unsafe { Helper::static_cast(column) }; - column.iter_data().flatten().collect::>() -} - pub fn get_value_from_batches(column_name: &str, batches: Vec) -> Value { assert_eq!(1, batches.len()); assert_eq!(batches[0].num_columns(), 1); diff --git a/src/query/src/tests/mean_test.rs b/src/query/src/tests/mean_test.rs deleted file mode 100644 index 288d25ba2a..0000000000 --- a/src/query/src/tests/mean_test.rs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::WrapperType; -use datatypes::value::OrderedFloat; -use num_traits::AsPrimitive; - -use crate::error::Result; -use crate::tests::{exec_selection, function}; -use crate::QueryEngine; - -#[tokio::test] -async fn test_mean_aggregator() -> Result<()> { - common_telemetry::init_default_ut_logging(); - let engine = function::create_query_engine(); - - macro_rules! test_mean { - ([], $( { $T:ty } ),*) => { - $( - let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_mean_success::<$T>(&column_name, "numbers", engine.clone()).await?; - )* - } - } - for_all_primitive_types! { test_mean } - Ok(()) -} - -async fn test_mean_success( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: WrapperType + AsPrimitive, -{ - let sql = format!("select MEAN({column_name}) as mean from {table_name}"); - let result = exec_selection(engine.clone(), &sql).await; - let value = function::get_value_from_batches("mean", result); - - let numbers = - function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; - let numbers = numbers.iter().map(|&n| n.as_()).collect::>(); - let expected = numbers.iter().sum::() / (numbers.len() as f64); - let Value::Float64(OrderedFloat(value)) = value else { - unreachable!() - }; - assert!( - (value - expected).abs() < 1e-3, - "expected {expected}, actual {value}" - ); - Ok(()) -} diff --git a/src/query/src/tests/polyval_test.rs b/src/query/src/tests/polyval_test.rs deleted file mode 100644 index 5e0f44d559..0000000000 --- a/src/query/src/tests/polyval_test.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::WrapperType; -use num_traits::AsPrimitive; - -use crate::error::Result; -use crate::tests::{exec_selection, function}; -use crate::QueryEngine; - -#[tokio::test] -async fn test_polyval_aggregator() -> Result<()> { - common_telemetry::init_default_ut_logging(); - let engine = function::create_query_engine(); - - macro_rules! test_polyval { - ([], $( { $T:ty } ),*) => { - $( - let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_polyval_success::<$T, <<<$T as WrapperType>::LogicalType as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>(&column_name, "numbers", engine.clone()).await?; - )* - } - } - for_all_primitive_types! { test_polyval } - Ok(()) -} - -async fn test_polyval_success( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: WrapperType, - PolyT: WrapperType, - T::Native: AsPrimitive, - PolyT::Native: std::ops::Mul + std::iter::Sum, - i64: AsPrimitive, -{ - let sql = format!("select POLYVAL({column_name}, 0) as polyval from {table_name}"); - let result = exec_selection(engine.clone(), &sql).await; - let value = function::get_value_from_batches("polyval", result); - - let numbers = - function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; - let expected_value = numbers.iter().copied(); - let x = 0i64; - let len = expected_value.len(); - let expected_native: PolyT::Native = expected_value - .enumerate() - .map(|(i, v)| v.into_native().as_() * (x.pow((len - 1 - i) as u32)).as_()) - .sum(); - assert_eq!(value, PolyT::from_native(expected_native).into()); - Ok(()) -} diff --git a/src/query/src/tests/scipy_stats_norm_cdf_test.rs b/src/query/src/tests/scipy_stats_norm_cdf_test.rs deleted file mode 100644 index de4015c0b7..0000000000 --- a/src/query/src/tests/scipy_stats_norm_cdf_test.rs +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use datatypes::for_all_primitive_types; -use datatypes::types::WrapperType; -use num_traits::AsPrimitive; -use statrs::distribution::{ContinuousCDF, Normal}; -use statrs::statistics::Statistics; - -use crate::error::Result; -use crate::tests::{exec_selection, function}; -use crate::QueryEngine; - -#[tokio::test] -async fn test_scipy_stats_norm_cdf_aggregator() -> Result<()> { - common_telemetry::init_default_ut_logging(); - let engine = function::create_query_engine(); - - macro_rules! test_scipy_stats_norm_cdf { - ([], $( { $T:ty } ),*) => { - $( - let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_scipy_stats_norm_cdf_success::<$T>(&column_name, "numbers", engine.clone()).await?; - )* - } - } - for_all_primitive_types! { test_scipy_stats_norm_cdf } - Ok(()) -} - -async fn test_scipy_stats_norm_cdf_success( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: WrapperType + AsPrimitive, -{ - let sql = format!( - "select SCIPYSTATSNORMCDF({column_name},2.0) as scipy_stats_norm_cdf from {table_name}", - ); - let result = exec_selection(engine.clone(), &sql).await; - let value = function::get_value_from_batches("scipy_stats_norm_cdf", result); - - let numbers = - function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; - let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); - let mean = expected_value.clone().mean(); - let stddev = expected_value.std_dev(); - - let n = Normal::new(mean, stddev).unwrap(); - let expected_value = n.cdf(2.0); - - assert_eq!(value, expected_value.into()); - Ok(()) -} diff --git a/src/query/src/tests/scipy_stats_norm_pdf.rs b/src/query/src/tests/scipy_stats_norm_pdf.rs deleted file mode 100644 index 85e0cd7771..0000000000 --- a/src/query/src/tests/scipy_stats_norm_pdf.rs +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use datatypes::for_all_primitive_types; -use datatypes::types::WrapperType; -use num_traits::AsPrimitive; -use statrs::distribution::{Continuous, Normal}; -use statrs::statistics::Statistics; - -use crate::error::Result; -use crate::tests::{exec_selection, function}; -use crate::QueryEngine; - -#[tokio::test] -async fn test_scipy_stats_norm_pdf_aggregator() -> Result<()> { - common_telemetry::init_default_ut_logging(); - let engine = function::create_query_engine(); - - macro_rules! test_scipy_stats_norm_pdf { - ([], $( { $T:ty } ),*) => { - $( - let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_scipy_stats_norm_pdf_success::<$T>(&column_name, "numbers", engine.clone()).await?; - )* - } - } - for_all_primitive_types! { test_scipy_stats_norm_pdf } - Ok(()) -} - -async fn test_scipy_stats_norm_pdf_success( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: WrapperType + AsPrimitive, -{ - let sql = format!( - "select SCIPYSTATSNORMPDF({column_name},2.0) as scipy_stats_norm_pdf from {table_name}" - ); - let result = exec_selection(engine.clone(), &sql).await; - let value = function::get_value_from_batches("scipy_stats_norm_pdf", result); - - let numbers = - function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; - let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); - let mean = expected_value.clone().mean(); - let stddev = expected_value.std_dev(); - - let n = Normal::new(mean, stddev).unwrap(); - let expected_value = n.pdf(2.0); - - assert_eq!(value, expected_value.into()); - Ok(()) -}