diff --git a/src/common/function/src/scalars/aggregate.rs b/src/common/function/src/scalars/aggregate.rs index 8a4712a1b8..0373b54712 100644 --- a/src/common/function/src/scalars/aggregate.rs +++ b/src/common/function/src/scalars/aggregate.rs @@ -35,8 +35,6 @@ pub use polyval::PolyvalAccumulatorCreator; pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator; pub use scipy_stats_norm_pdf::ScipyStatsNormPdfAccumulatorCreator; -use crate::scalars::FunctionRegistry; - /// A function creates `AggregateFunctionCreator`. /// "Aggregator" *is* AggregatorFunction. Since the later one is long, we named an short alias for it. /// The two names might be used interchangeably. diff --git a/src/common/function/src/scalars/aggregate/argmax.rs b/src/common/function/src/scalars/aggregate/argmax.rs index 0b63a766bd..63a45fa855 100644 --- a/src/common/function/src/scalars/aggregate/argmax.rs +++ b/src/common/function/src/scalars/aggregate/argmax.rs @@ -20,24 +20,22 @@ use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Resul use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; -use datatypes::vectors::ConstantVector; +use datatypes::types::{LogicalPrimitiveType, WrapperType}; +use datatypes::vectors::{ConstantVector, Helper}; use datatypes::with_match_primitive_type_id; use snafu::ensure; // https://numpy.org/doc/stable/reference/generated/numpy.argmax.html // return the index of the max value #[derive(Debug, Default)] -pub struct Argmax -where - T: Primitive + PartialOrd, -{ +pub struct Argmax { max: Option, n: u64, } impl Argmax where - T: Primitive + PartialOrd, + T: PartialOrd, { fn update(&mut self, value: T, index: u64) { if let Some(Ordering::Less) = self.max.partial_cmp(&Some(value)) { @@ -49,8 +47,7 @@ where impl Accumulator for Argmax where - T: Primitive + PartialOrd, - for<'a> T: Scalar = T>, + T: WrapperType + PartialOrd, { fn state(&self) -> Result> { match self.max { @@ -66,10 +63,10 @@ where let column = &values[0]; let column: &::VectorType = if column.is_const() { - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; for (i, v) in column.iter_data().enumerate() { if let Some(value) = v { @@ -93,8 +90,8 @@ where let max = &states[0]; let index = &states[1]; - let max: &::VectorType = unsafe { VectorHelper::static_cast(max) }; - let index: &::VectorType = unsafe { VectorHelper::static_cast(index) }; + let max: &::VectorType = unsafe { Helper::static_cast(max) }; + let index: &::VectorType = unsafe { Helper::static_cast(index) }; index .iter_data() .flatten() @@ -122,7 +119,7 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(Argmax::<$S>::default())) + Ok(Box::new(Argmax::<<$S as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -154,7 +151,7 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator { #[cfg(test)] mod test { - use datatypes::vectors::PrimitiveVector; + use datatypes::vectors::Int32Vector; use super::*; #[test] @@ -166,21 +163,19 @@ mod test { // test update one not-null value let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![Some(42)]))]; + let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; assert!(argmax.update_batch(&v).is_ok()); assert_eq!(Value::from(0_u64), argmax.evaluate().unwrap()); // test update one null value let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ - Option::::None, - ]))]; + let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; assert!(argmax.update_batch(&v).is_ok()); assert_eq!(Value::Null, argmax.evaluate().unwrap()); // test update no null-value batch let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ + let v: Vec = vec![Arc::new(Int32Vector::from(vec![ Some(-1i32), Some(1), Some(3), @@ -190,7 +185,7 @@ mod test { // test update null-value batch let mut argmax = Argmax::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ + let v: Vec = vec![Arc::new(Int32Vector::from(vec![ Some(-2i32), None, Some(4), @@ -201,7 +196,7 @@ mod test { // test update with constant vector let mut argmax = Argmax::::default(); let v: Vec = vec![Arc::new(ConstantVector::new( - Arc::new(PrimitiveVector::::from_vec(vec![4])), + Arc::new(Int32Vector::from_vec(vec![4])), 10, ))]; assert!(argmax.update_batch(&v).is_ok()); diff --git a/src/common/function/src/scalars/aggregate/diff.rs b/src/common/function/src/scalars/aggregate/diff.rs index 2e0b38f1e7..3f7ecc2400 100644 --- a/src/common/function/src/scalars/aggregate/diff.rs +++ b/src/common/function/src/scalars/aggregate/diff.rs @@ -22,7 +22,6 @@ use common_query::error::{ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; -use datatypes::types::{LogicalPrimitiveType, WrapperType}; use datatypes::value::ListValue; use datatypes::vectors::{ConstantVector, Helper, ListVector}; use datatypes::with_match_primitive_type_id; @@ -32,20 +31,12 @@ use snafu::{ensure, OptionExt, ResultExt}; // https://numpy.org/doc/stable/reference/generated/numpy.diff.html // I is the input type, O is the output type. #[derive(Debug, Default)] -pub struct Diff -where - I: WrapperType, - O: WrapperType, -{ +pub struct Diff { values: Vec, _phantom: PhantomData, } -impl Diff -where - I: WrapperType, - O: WrapperType, -{ +impl Diff { fn push(&mut self, value: I) { self.values.push(value); } diff --git a/src/common/function/src/scalars/aggregate/median.rs b/src/common/function/src/scalars/aggregate/median.rs index 4c445c0fb9..facbd8702a 100644 --- a/src/common/function/src/scalars/aggregate/median.rs +++ b/src/common/function/src/scalars/aggregate/median.rs @@ -25,7 +25,7 @@ use common_query::prelude::*; use datatypes::prelude::*; use datatypes::types::OrdPrimitive; use datatypes::value::ListValue; -use datatypes::vectors::{ConstantVector, ListVector}; +use datatypes::vectors::{ConstantVector, Helper, ListVector}; use datatypes::with_match_primitive_type_id; use num::NumCast; use snafu::{ensure, OptionExt, ResultExt}; @@ -51,7 +51,7 @@ use snafu::{ensure, OptionExt, ResultExt}; #[derive(Debug, Default)] pub struct Median where - T: Primitive, + T: WrapperType, { greater: BinaryHeap>>, not_greater: BinaryHeap>, @@ -59,7 +59,7 @@ where impl Median where - T: Primitive, + T: WrapperType, { fn push(&mut self, value: T) { let value = OrdPrimitive::(value); @@ -87,8 +87,7 @@ where // to use them. impl Accumulator for Median where - T: Primitive, - for<'a> T: Scalar = T>, + T: WrapperType + NumCast, { // This function serializes our state to `ScalarValue`, which DataFusion uses to pass this // state between execution stages. Note that this can be arbitrary data. @@ -105,7 +104,7 @@ where .collect::>(); Ok(vec![Value::List(ListValue::new( Some(Box::new(nums)), - T::default().into().data_type(), + T::LogicalType::build_data_type(), ))]) } @@ -123,10 +122,10 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; (0..len).for_each(|_| { for v in column.iter_data().flatten() { @@ -156,9 +155,10 @@ where ), })?; for state in states.values_iter() { - let state = state.context(FromScalarValueSnafu)?; - // merging state is simply accumulate stored numbers from others', so just call update - self.update_batch(&[state])? + if let Some(state) = state.context(FromScalarValueSnafu)? { + // merging state is simply accumulate stored numbers from others', so just call update + self.update_batch(&[state])? + } } Ok(()) } @@ -202,7 +202,7 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(Median::<$S>::default())) + Ok(Box::new(Median::<<$S as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -230,7 +230,7 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator { #[cfg(test)] mod test { - use datatypes::vectors::PrimitiveVector; + use datatypes::vectors::Int32Vector; use super::*; #[test] @@ -244,21 +244,19 @@ mod test { // test update one not-null value let mut median = Median::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![Some(42)]))]; + let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; assert!(median.update_batch(&v).is_ok()); assert_eq!(Value::Int32(42), median.evaluate().unwrap()); // test update one null value let mut median = Median::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ - Option::::None, - ]))]; + let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; assert!(median.update_batch(&v).is_ok()); assert_eq!(Value::Null, median.evaluate().unwrap()); // test update no null-value batch let mut median = Median::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ + let v: Vec = vec![Arc::new(Int32Vector::from(vec![ Some(-1i32), Some(1), Some(2), @@ -268,7 +266,7 @@ mod test { // test update null-value batch let mut median = Median::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ + let v: Vec = vec![Arc::new(Int32Vector::from(vec![ Some(-2i32), None, Some(3), @@ -280,7 +278,7 @@ mod test { // test update with constant vector let mut median = Median::::default(); let v: Vec = vec![Arc::new(ConstantVector::new( - Arc::new(PrimitiveVector::::from_vec(vec![4])), + Arc::new(Int32Vector::from_vec(vec![4])), 10, ))]; assert!(median.update_batch(&v).is_ok()); diff --git a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs b/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs index 8f43b64e92..9a50fb2c55 100644 --- a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs +++ b/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs @@ -23,7 +23,7 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; use datatypes::value::{ListValue, OrderedFloat}; -use datatypes::vectors::{ConstantVector, Float64Vector, ListVector}; +use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector}; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use snafu::{ensure, OptionExt, ResultExt}; @@ -33,18 +33,12 @@ use statrs::statistics::Statistics; // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html #[derive(Debug, Default)] -pub struct ScipyStatsNormCdf -where - T: Primitive + AsPrimitive + std::iter::Sum, -{ +pub struct ScipyStatsNormCdf { values: Vec, x: Option, } -impl ScipyStatsNormCdf -where - T: Primitive + AsPrimitive + std::iter::Sum, -{ +impl ScipyStatsNormCdf { fn push(&mut self, value: T) { self.values.push(value); } @@ -52,8 +46,8 @@ where impl Accumulator for ScipyStatsNormCdf where - T: Primitive + AsPrimitive + std::iter::Sum, - for<'a> T: Scalar = T>, + T: WrapperType + std::iter::Sum, + T::Native: AsPrimitive, { fn state(&self) -> Result> { let nums = self @@ -64,7 +58,7 @@ where Ok(vec![ Value::List(ListValue::new( Some(Box::new(nums)), - T::default().into().data_type(), + T::LogicalType::build_data_type(), )), self.x.into(), ]) @@ -86,14 +80,14 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; let x = &values[1]; - let x = VectorHelper::check_get_scalar::(x).context(error::InvalidInputsSnafu { + let x = Helper::check_get_scalar::(x).context(error::InvalidInputsSnafu { err_msg: "expecting \"SCIPYSTATSNORMCDF\" function's second argument to be a positive integer", })?; let first = x.get(0); @@ -160,19 +154,19 @@ where ), })?; for value in values.values_iter() { - let value = value.context(FromScalarValueSnafu)?; - let column: &::VectorType = unsafe { VectorHelper::static_cast(&value) }; - for v in column.iter_data().flatten() { - self.push(v); + if let Some(value) = value.context(FromScalarValueSnafu)? { + let column: &::VectorType = unsafe { Helper::static_cast(&value) }; + for v in column.iter_data().flatten() { + self.push(v); + } } } Ok(()) } fn evaluate(&self) -> Result { - let values = self.values.iter().map(|&v| v.as_()).collect::>(); - let mean = values.clone().mean(); - let std_dev = values.std_dev(); + let mean = self.values.iter().map(|v| v.into_native().as_()).mean(); + let std_dev = self.values.iter().map(|v| v.into_native().as_()).std_dev(); if mean.is_nan() || std_dev.is_nan() { Ok(Value::Null) } else { @@ -198,7 +192,7 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(ScipyStatsNormCdf::<$S>::default())) + Ok(Box::new(ScipyStatsNormCdf::<<$S as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -230,7 +224,7 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator { #[cfg(test)] mod test { - use datatypes::vectors::PrimitiveVector; + use datatypes::vectors::{Float64Vector, Int32Vector}; use super::*; #[test] @@ -244,12 +238,8 @@ mod test { // test update no null-value batch let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), + Arc::new(Float64Vector::from(vec![ Some(2.0_f64), Some(2.0_f64), Some(2.0_f64), @@ -264,13 +254,8 @@ mod test { // test update null-value batch let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-2i32), - None, - Some(3), - Some(4), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])), + Arc::new(Float64Vector::from(vec![ Some(2.0_f64), None, Some(2.0_f64), diff --git a/src/common/function/src/scalars/math/rate.rs b/src/common/function/src/scalars/math/rate.rs index 628a19408a..ad03485a36 100644 --- a/src/common/function/src/scalars/math/rate.rs +++ b/src/common/function/src/scalars/math/rate.rs @@ -14,10 +14,10 @@ use std::fmt; -use arrow::array::Array; -use common_query::error::{FromArrowArraySnafu, Result, TypeCastSnafu}; +use common_query::error::{self, Result}; use common_query::prelude::{Signature, Volatility}; -use datatypes::arrow; +use datatypes::arrow::compute::kernels::{arithmetic, cast}; +use datatypes::arrow::datatypes::DataType; use datatypes::prelude::*; use datatypes::vectors::{Helper, VectorRef}; use snafu::ResultExt; @@ -51,28 +51,21 @@ impl Function for RateFunction { let val = &columns[0].to_arrow_array(); let val_0 = val.slice(0, val.len() - 1); let val_1 = val.slice(1, val.len() - 1); - let dv = arrow::compute::arithmetics::sub(&*val_1, &*val_0); + let dv = arithmetic::subtract_dyn(&val_1, &val_0).context(error::ArrowComputeSnafu)?; let ts = &columns[1].to_arrow_array(); let ts_0 = ts.slice(0, ts.len() - 1); let ts_1 = ts.slice(1, ts.len() - 1); - let dt = arrow::compute::arithmetics::sub(&*ts_1, &*ts_0); - fn all_to_f64(array: &dyn Array) -> Result> { - Ok(arrow::compute::cast::cast( - array, - &arrow::datatypes::DataType::Float64, - arrow::compute::cast::CastOptions { - wrapped: true, - partial: true, - }, - ) - .context(TypeCastSnafu { - typ: arrow::datatypes::DataType::Float64, - })?) - } - let dv = all_to_f64(&*dv)?; - let dt = all_to_f64(&*dt)?; - let rate = arrow::compute::arithmetics::div(&*dv, &*dt); - let v = Helper::try_into_vector(&rate).context(FromArrowArraySnafu)?; + let dt = arithmetic::subtract_dyn(&ts_1, &ts_0).context(error::ArrowComputeSnafu)?; + + let dv = cast::cast(&dv, &DataType::Float64).context(error::TypeCastSnafu { + typ: DataType::Float64, + })?; + let dt = cast::cast(&dt, &DataType::Float64).context(error::TypeCastSnafu { + typ: DataType::Float64, + })?; + let rate = arithmetic::divide_dyn(&dv, &dt).context(error::ArrowComputeSnafu)?; + let v = Helper::try_into_vector(&rate).context(error::FromArrowArraySnafu)?; + Ok(v) } } @@ -81,9 +74,8 @@ impl Function for RateFunction { mod tests { use std::sync::Arc; - use arrow::array::Float64Array; use common_query::prelude::TypeSignature; - use datatypes::vectors::{Float32Vector, Int64Vector}; + use datatypes::vectors::{Float32Vector, Float64Vector, Int64Vector}; use super::*; #[test] @@ -108,9 +100,7 @@ mod tests { Arc::new(Int64Vector::from_vec(ts)), ]; let vector = rate.eval(FunctionContext::default(), &args).unwrap(); - let arr = vector.to_arrow_array(); - let expect = Arc::new(Float64Array::from_vec(vec![2.0, 3.0])); - let res = arrow::compute::comparison::eq(&*arr, &*expect); - res.iter().for_each(|x| assert!(matches!(x, Some(true)))); + let expect: VectorRef = Arc::new(Float64Vector::from_vec(vec![2.0, 3.0])); + assert_eq!(expect, vector); } } diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index 1d1e842d29..6df18e841d 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -127,10 +127,20 @@ pub enum Error { source: BoxedError, }, - #[snafu(display("Fail to cast array to {:?}, source: {}", typ, source))] + #[snafu(display("Failed to cast array to {:?}, source: {}", typ, source))] TypeCast { source: ArrowError, typ: arrow::datatypes::DataType, + backtrace: Backtrace, + }, + + #[snafu(display( + "Failed to perform compute operation on arrow arrays, source: {}", + source + ))] + ArrowCompute { + source: ArrowError, + backtrace: Backtrace, }, #[snafu(display("Query engine fail to cast value: {}", source))] @@ -139,7 +149,7 @@ pub enum Error { source: DataTypeError, }, - #[snafu(display("Fail to get scalar vector, {}", source))] + #[snafu(display("Failed to get scalar vector, {}", source))] GetScalarVector { #[snafu(backtrace)] source: DataTypeError, @@ -159,7 +169,8 @@ impl ErrorExt for Error { | Error::InvalidInputCol { .. } | Error::BadAccumulatorImpl { .. } | Error::ToScalarValue { .. } - | Error::GetScalarVector { .. } => StatusCode::EngineExecuteQuery, + | Error::GetScalarVector { .. } + | Error::ArrowCompute { .. } => StatusCode::EngineExecuteQuery, Error::InvalidInputs { source, .. } | Error::IntoVector { source, .. } diff --git a/src/datatypes/src/prelude.rs b/src/datatypes/src/prelude.rs index f6bd298316..b1afe93042 100644 --- a/src/datatypes/src/prelude.rs +++ b/src/datatypes/src/prelude.rs @@ -16,5 +16,6 @@ pub use crate::data_type::{ConcreteDataType, DataType, DataTypeRef}; pub use crate::macros::*; pub use crate::scalars::{Scalar, ScalarRef, ScalarVector, ScalarVectorBuilder}; pub use crate::type_id::LogicalTypeId; +pub use crate::types::{LogicalPrimitiveType, WrapperType}; pub use crate::value::{Value, ValueRef}; pub use crate::vectors::{MutableVector, Validity, Vector, VectorRef}; diff --git a/src/datatypes/src/types.rs b/src/datatypes/src/types.rs index 186704fdfd..8f40c563de 100644 --- a/src/datatypes/src/types.rs +++ b/src/datatypes/src/types.rs @@ -31,7 +31,10 @@ pub use list_type::ListType; pub use null_type::NullType; pub use primitive_type::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, LogicalPrimitiveType, - NativeType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType, + NativeType, OrdPrimitive, UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType, }; pub use string_type::StringType; -pub use timestamp_type::*; +pub use timestamp_type::{ + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, TimestampType, +};