From 58c26def6b945bb53d4261eb8f9992a1d3de1eab Mon Sep 17 00:00:00 2001 From: Yingwen Date: Wed, 7 Dec 2022 19:55:07 +0800 Subject: [PATCH] fix: fix argmin/percentile/clip/interp/scipy_stats_norm_pdf errors (#718) fix: fix argmin/percentile/clip/interp/scipy_stats_norm_pdf compiler errors --- src/common/function/src/scalars/aggregate.rs | 2 + .../function/src/scalars/aggregate/argmin.rs | 38 +++--- .../src/scalars/aggregate/percentile.rs | 108 ++++++----------- .../scalars/aggregate/scipy_stats_norm_cdf.rs | 2 +- .../scalars/aggregate/scipy_stats_norm_pdf.rs | 62 ++++------ .../function/src/scalars/expression/binary.rs | 32 +++-- .../function/src/scalars/expression/ctx.rs | 3 +- .../function/src/scalars/expression/unary.rs | 5 +- src/common/function/src/scalars/numpy.rs | 1 - src/common/function/src/scalars/numpy/clip.rs | 8 +- .../function/src/scalars/numpy/interp.rs | 114 +++++++++--------- src/common/query/src/error.rs | 18 ++- src/datatypes/src/types/primitive_type.rs | 5 +- src/datatypes/src/value.rs | 2 +- 14 files changed, 177 insertions(+), 223 deletions(-) diff --git a/src/common/function/src/scalars/aggregate.rs b/src/common/function/src/scalars/aggregate.rs index 0373b54712..8a4712a1b8 100644 --- a/src/common/function/src/scalars/aggregate.rs +++ b/src/common/function/src/scalars/aggregate.rs @@ -35,6 +35,8 @@ pub use polyval::PolyvalAccumulatorCreator; pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator; pub use scipy_stats_norm_pdf::ScipyStatsNormPdfAccumulatorCreator; +use crate::scalars::FunctionRegistry; + /// A function creates `AggregateFunctionCreator`. /// "Aggregator" *is* AggregatorFunction. Since the later one is long, we named an short alias for it. /// The two names might be used interchangeably. diff --git a/src/common/function/src/scalars/aggregate/argmin.rs b/src/common/function/src/scalars/aggregate/argmin.rs index bcbd6571c5..5b93561286 100644 --- a/src/common/function/src/scalars/aggregate/argmin.rs +++ b/src/common/function/src/scalars/aggregate/argmin.rs @@ -20,23 +20,20 @@ use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Resul use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; -use datatypes::vectors::ConstantVector; +use datatypes::vectors::{ConstantVector, Helper}; use datatypes::with_match_primitive_type_id; use snafu::ensure; // // https://numpy.org/doc/stable/reference/generated/numpy.argmin.html #[derive(Debug, Default)] -pub struct Argmin -where - T: Primitive + PartialOrd, -{ +pub struct Argmin { min: Option, n: u32, } impl Argmin where - T: Primitive + PartialOrd, + T: Copy + PartialOrd, { fn update(&mut self, value: T, index: u32) { match self.min { @@ -56,8 +53,7 @@ where impl Accumulator for Argmin where - T: Primitive + PartialOrd, - for<'a> T: Scalar = T>, + T: WrapperType + PartialOrd, { fn state(&self) -> Result> { match self.min { @@ -75,10 +71,10 @@ where let column = &values[0]; let column: &::VectorType = if column.is_const() { - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; for (i, v) in column.iter_data().enumerate() { if let Some(value) = v { @@ -102,8 +98,8 @@ where let min = &states[0]; let index = &states[1]; - let min: &::VectorType = unsafe { VectorHelper::static_cast(min) }; - let index: &::VectorType = unsafe { VectorHelper::static_cast(index) }; + let min: &::VectorType = unsafe { Helper::static_cast(min) }; + let index: &::VectorType = unsafe { Helper::static_cast(index) }; index .iter_data() .flatten() @@ -131,7 +127,7 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(Argmin::<$S>::default())) + Ok(Box::new(Argmin::<<$S as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -163,7 +159,7 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator { #[cfg(test)] mod test { - use datatypes::vectors::PrimitiveVector; + use datatypes::vectors::Int32Vector; use super::*; #[test] @@ -175,21 +171,19 @@ mod test { // test update one not-null value let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![Some(42)]))]; + let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; assert!(argmin.update_batch(&v).is_ok()); assert_eq!(Value::from(0_u32), argmin.evaluate().unwrap()); // test update one null value let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ - Option::::None, - ]))]; + let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; assert!(argmin.update_batch(&v).is_ok()); assert_eq!(Value::Null, argmin.evaluate().unwrap()); // test update no null-value batch let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ + let v: Vec = vec![Arc::new(Int32Vector::from(vec![ Some(-1i32), Some(1), Some(3), @@ -199,7 +193,7 @@ mod test { // test update null-value batch let mut argmin = Argmin::::default(); - let v: Vec = vec![Arc::new(PrimitiveVector::::from(vec![ + let v: Vec = vec![Arc::new(Int32Vector::from(vec![ Some(-2i32), None, Some(4), @@ -210,7 +204,7 @@ mod test { // test update with constant vector let mut argmin = Argmin::::default(); let v: Vec = vec![Arc::new(ConstantVector::new( - Arc::new(PrimitiveVector::::from_vec(vec![4])), + Arc::new(Int32Vector::from_vec(vec![4])), 10, ))]; assert!(argmin.update_batch(&v).is_ok()); diff --git a/src/common/function/src/scalars/aggregate/percentile.rs b/src/common/function/src/scalars/aggregate/percentile.rs index 1b642dd274..1517f90e62 100644 --- a/src/common/function/src/scalars/aggregate/percentile.rs +++ b/src/common/function/src/scalars/aggregate/percentile.rs @@ -26,7 +26,7 @@ use common_query::prelude::*; use datatypes::prelude::*; use datatypes::types::OrdPrimitive; use datatypes::value::{ListValue, OrderedFloat}; -use datatypes::vectors::{ConstantVector, Float64Vector, ListVector}; +use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector}; use datatypes::with_match_primitive_type_id; use num::NumCast; use snafu::{ensure, OptionExt, ResultExt}; @@ -44,15 +44,15 @@ use snafu::{ensure, OptionExt, ResultExt}; // This optional method parameter specifies the method to use when the desired quantile lies between two data points i < j. // If g is the fractional part of the index surrounded by i and alpha and beta are correction constants modifying i and j. // i+g = (q-alpha)/(n-alpha-beta+1) -// Below, ‘q’ is the quantile value, ‘n’ is the sample size and alpha and beta are constants. The following formula gives an interpolation “i + g” of where the quantile would be in the sorted sample. -// With ‘i’ being the floor and ‘g’ the fractional part of the result. +// Below, 'q' is the quantile value, 'n' is the sample size and alpha and beta are constants. The following formula gives an interpolation "i + g" of where the quantile would be in the sorted sample. +// With 'i' being the floor and 'g' the fractional part of the result. // the default method is linear where // alpha = 1 // beta = 1 #[derive(Debug, Default)] pub struct Percentile where - T: Primitive, + T: WrapperType, { greater: BinaryHeap>>, not_greater: BinaryHeap>, @@ -62,7 +62,7 @@ where impl Percentile where - T: Primitive, + T: WrapperType, { fn push(&mut self, value: T) { let value = OrdPrimitive::(value); @@ -93,8 +93,7 @@ where impl Accumulator for Percentile where - T: Primitive, - for<'a> T: Scalar = T>, + T: WrapperType, { fn state(&self) -> Result> { let nums = self @@ -107,7 +106,7 @@ where Ok(vec![ Value::List(ListValue::new( Some(Box::new(nums)), - T::default().into().data_type(), + T::LogicalType::build_data_type(), )), self.p.into(), ]) @@ -129,14 +128,14 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; let x = &values[1]; - let x = VectorHelper::check_get_scalar::(x).context(error::InvalidInputsSnafu { + let x = Helper::check_get_scalar::(x).context(error::InvalidInputTypeSnafu { err_msg: "expecting \"POLYVAL\" function's second argument to be float64", })?; // `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0` @@ -209,10 +208,11 @@ where ), })?; for value in values.values_iter() { - let value = value.context(FromScalarValueSnafu)?; - let column: &::VectorType = unsafe { VectorHelper::static_cast(&value) }; - for v in column.iter_data().flatten() { - self.push(v); + if let Some(value) = value.context(FromScalarValueSnafu)? { + let column: &::VectorType = unsafe { Helper::static_cast(&value) }; + for v in column.iter_data().flatten() { + self.push(v); + } } } Ok(()) @@ -259,7 +259,7 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(Percentile::<$S>::default())) + Ok(Box::new(Percentile::<<$S as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -292,7 +292,7 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator { #[cfg(test)] mod test { - use datatypes::vectors::PrimitiveVector; + use datatypes::vectors::{Float64Vector, Int32Vector}; use super::*; #[test] @@ -307,8 +307,8 @@ mod test { // test update one not-null value let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![Some(42)])), - Arc::new(PrimitiveVector::::from(vec![Some(100.0_f64)])), + Arc::new(Int32Vector::from(vec![Some(42)])), + Arc::new(Float64Vector::from(vec![Some(100.0_f64)])), ]; assert!(percentile.update_batch(&v).is_ok()); assert_eq!(Value::from(42.0_f64), percentile.evaluate().unwrap()); @@ -316,8 +316,8 @@ mod test { // test update one null value let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![Option::::None])), - Arc::new(PrimitiveVector::::from(vec![Some(100.0_f64)])), + Arc::new(Int32Vector::from(vec![Option::::None])), + Arc::new(Float64Vector::from(vec![Some(100.0_f64)])), ]; assert!(percentile.update_batch(&v).is_ok()); assert_eq!(Value::Null, percentile.evaluate().unwrap()); @@ -325,12 +325,8 @@ mod test { // test update no null-value batch let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), + Arc::new(Float64Vector::from(vec![ Some(100.0_f64), Some(100.0_f64), Some(100.0_f64), @@ -342,13 +338,8 @@ mod test { // test update null-value batch let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-2i32), - None, - Some(3), - Some(4), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])), + Arc::new(Float64Vector::from(vec![ Some(100.0_f64), Some(100.0_f64), Some(100.0_f64), @@ -362,13 +353,10 @@ mod test { let mut percentile = Percentile::::default(); let v: Vec = vec![ Arc::new(ConstantVector::new( - Arc::new(PrimitiveVector::::from_vec(vec![4])), + Arc::new(Int32Vector::from_vec(vec![4])), 2, )), - Arc::new(PrimitiveVector::::from(vec![ - Some(100.0_f64), - Some(100.0_f64), - ])), + Arc::new(Float64Vector::from(vec![Some(100.0_f64), Some(100.0_f64)])), ]; assert!(percentile.update_batch(&v).is_ok()); assert_eq!(Value::from(4_f64), percentile.evaluate().unwrap()); @@ -376,12 +364,8 @@ mod test { // test left border let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), + Arc::new(Float64Vector::from(vec![ Some(0.0_f64), Some(0.0_f64), Some(0.0_f64), @@ -393,12 +377,8 @@ mod test { // test medium let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), + Arc::new(Float64Vector::from(vec![ Some(50.0_f64), Some(50.0_f64), Some(50.0_f64), @@ -410,12 +390,8 @@ mod test { // test right border let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), + Arc::new(Float64Vector::from(vec![ Some(100.0_f64), Some(100.0_f64), Some(100.0_f64), @@ -431,12 +407,8 @@ mod test { // >> 6.400000000000 let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(10i32), - Some(7), - Some(4), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(10i32), Some(7), Some(4)])), + Arc::new(Float64Vector::from(vec![ Some(40.0_f64), Some(40.0_f64), Some(40.0_f64), @@ -451,12 +423,8 @@ mod test { // >> 9.7000000000000011 let mut percentile = Percentile::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(10i32), - Some(7), - Some(4), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(10i32), Some(7), Some(4)])), + Arc::new(Float64Vector::from(vec![ Some(95.0_f64), Some(95.0_f64), Some(95.0_f64), diff --git a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs b/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs index 9a50fb2c55..caa07248a3 100644 --- a/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs +++ b/src/common/function/src/scalars/aggregate/scipy_stats_norm_cdf.rs @@ -87,7 +87,7 @@ where }; let x = &values[1]; - let x = Helper::check_get_scalar::(x).context(error::InvalidInputsSnafu { + let x = Helper::check_get_scalar::(x).context(error::InvalidInputTypeSnafu { err_msg: "expecting \"SCIPYSTATSNORMCDF\" function's second argument to be a positive integer", })?; let first = x.get(0); diff --git a/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs b/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs index e381d11b54..186d59a890 100644 --- a/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs +++ b/src/common/function/src/scalars/aggregate/scipy_stats_norm_pdf.rs @@ -23,7 +23,7 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; use datatypes::value::{ListValue, OrderedFloat}; -use datatypes::vectors::{ConstantVector, Float64Vector, ListVector}; +use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector}; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use snafu::{ensure, OptionExt, ResultExt}; @@ -33,18 +33,12 @@ use statrs::statistics::Statistics; // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html #[derive(Debug, Default)] -pub struct ScipyStatsNormPdf -where - T: Primitive + AsPrimitive + std::iter::Sum, -{ +pub struct ScipyStatsNormPdf { values: Vec, x: Option, } -impl ScipyStatsNormPdf -where - T: Primitive + AsPrimitive + std::iter::Sum, -{ +impl ScipyStatsNormPdf { fn push(&mut self, value: T) { self.values.push(value); } @@ -52,8 +46,8 @@ where impl Accumulator for ScipyStatsNormPdf where - T: Primitive + AsPrimitive + std::iter::Sum, - for<'a> T: Scalar = T>, + T: WrapperType, + T::Native: AsPrimitive + std::iter::Sum, { fn state(&self) -> Result> { let nums = self @@ -64,7 +58,7 @@ where Ok(vec![ Value::List(ListValue::new( Some(Box::new(nums)), - T::default().into().data_type(), + T::LogicalType::build_data_type(), )), self.x.into(), ]) @@ -86,14 +80,14 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; let x = &values[1]; - let x = VectorHelper::check_get_scalar::(x).context(error::InvalidInputsSnafu { + let x = Helper::check_get_scalar::(x).context(error::InvalidInputTypeSnafu { err_msg: "expecting \"SCIPYSTATSNORMPDF\" function's second argument to be a positive integer", })?; let first = x.get(0); @@ -160,19 +154,20 @@ where ), })?; for value in values.values_iter() { - let value = value.context(FromScalarValueSnafu)?; - let column: &::VectorType = unsafe { VectorHelper::static_cast(&value) }; - for v in column.iter_data().flatten() { - self.push(v); + if let Some(value) = value.context(FromScalarValueSnafu)? { + let column: &::VectorType = unsafe { Helper::static_cast(&value) }; + for v in column.iter_data().flatten() { + self.push(v); + } } } Ok(()) } fn evaluate(&self) -> Result { - let values = self.values.iter().map(|&v| v.as_()).collect::>(); - let mean = values.clone().mean(); - let std_dev = values.std_dev(); + let mean = self.values.iter().map(|v| v.into_native().as_()).mean(); + let std_dev = self.values.iter().map(|v| v.into_native().as_()).std_dev(); + if mean.is_nan() || std_dev.is_nan() { Ok(Value::Null) } else { @@ -198,7 +193,7 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(ScipyStatsNormPdf::<$S>::default())) + Ok(Box::new(ScipyStatsNormPdf::<<$S as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -230,7 +225,7 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator { #[cfg(test)] mod test { - use datatypes::vectors::PrimitiveVector; + use datatypes::vectors::{Float64Vector, Int32Vector}; use super::*; #[test] @@ -244,12 +239,8 @@ mod test { // test update no null-value batch let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])), + Arc::new(Float64Vector::from(vec![ Some(2.0_f64), Some(2.0_f64), Some(2.0_f64), @@ -264,13 +255,8 @@ mod test { // test update null-value batch let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::::default(); let v: Vec = vec![ - Arc::new(PrimitiveVector::::from(vec![ - Some(-2i32), - None, - Some(3), - Some(4), - ])), - Arc::new(PrimitiveVector::::from(vec![ + Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])), + Arc::new(Float64Vector::from(vec![ Some(2.0_f64), None, Some(2.0_f64), diff --git a/src/common/function/src/scalars/expression/binary.rs b/src/common/function/src/scalars/expression/binary.rs index b02e46c937..d1a9db8eb9 100644 --- a/src/common/function/src/scalars/expression/binary.rs +++ b/src/common/function/src/scalars/expression/binary.rs @@ -14,10 +14,10 @@ use std::iter; +use common_query::error::Result; use datatypes::prelude::*; -use datatypes::vectors::ConstantVector; +use datatypes::vectors::{ConstantVector, Helper}; -use crate::error::Result; use crate::scalars::expression::ctx::EvalContext; pub fn scalar_binary_op( @@ -36,10 +36,9 @@ where let result = match (l.is_const(), r.is_const()) { (false, true) => { - let left: &::VectorType = unsafe { VectorHelper::static_cast(l) }; - let right: &ConstantVector = unsafe { VectorHelper::static_cast(r) }; - let right: &::VectorType = - unsafe { VectorHelper::static_cast(right.inner()) }; + let left: &::VectorType = unsafe { Helper::static_cast(l) }; + let right: &ConstantVector = unsafe { Helper::static_cast(r) }; + let right: &::VectorType = unsafe { Helper::static_cast(right.inner()) }; let b = right.get_data(0); let it = left.iter_data().map(|a| f(a, b, ctx)); @@ -47,8 +46,8 @@ where } (false, false) => { - let left: &::VectorType = unsafe { VectorHelper::static_cast(l) }; - let right: &::VectorType = unsafe { VectorHelper::static_cast(r) }; + let left: &::VectorType = unsafe { Helper::static_cast(l) }; + let right: &::VectorType = unsafe { Helper::static_cast(r) }; let it = left .iter_data() @@ -58,25 +57,22 @@ where } (true, false) => { - let left: &ConstantVector = unsafe { VectorHelper::static_cast(l) }; - let left: &::VectorType = - unsafe { VectorHelper::static_cast(left.inner()) }; + let left: &ConstantVector = unsafe { Helper::static_cast(l) }; + let left: &::VectorType = unsafe { Helper::static_cast(left.inner()) }; let a = left.get_data(0); - let right: &::VectorType = unsafe { VectorHelper::static_cast(r) }; + let right: &::VectorType = unsafe { Helper::static_cast(r) }; let it = right.iter_data().map(|b| f(a, b, ctx)); ::VectorType::from_owned_iterator(it) } (true, true) => { - let left: &ConstantVector = unsafe { VectorHelper::static_cast(l) }; - let left: &::VectorType = - unsafe { VectorHelper::static_cast(left.inner()) }; + let left: &ConstantVector = unsafe { Helper::static_cast(l) }; + let left: &::VectorType = unsafe { Helper::static_cast(left.inner()) }; let a = left.get_data(0); - let right: &ConstantVector = unsafe { VectorHelper::static_cast(r) }; - let right: &::VectorType = - unsafe { VectorHelper::static_cast(right.inner()) }; + let right: &ConstantVector = unsafe { Helper::static_cast(r) }; + let right: &::VectorType = unsafe { Helper::static_cast(right.inner()) }; let b = right.get_data(0); let it = iter::repeat(a) diff --git a/src/common/function/src/scalars/expression/ctx.rs b/src/common/function/src/scalars/expression/ctx.rs index 7910bb82b8..c6735bd1d0 100644 --- a/src/common/function/src/scalars/expression/ctx.rs +++ b/src/common/function/src/scalars/expression/ctx.rs @@ -13,8 +13,7 @@ // limitations under the License. use chrono_tz::Tz; - -use crate::error::Error; +use common_query::error::Error; pub struct EvalContext { _tz: Tz, diff --git a/src/common/function/src/scalars/expression/unary.rs b/src/common/function/src/scalars/expression/unary.rs index a3434a2b0e..0862f711e1 100644 --- a/src/common/function/src/scalars/expression/unary.rs +++ b/src/common/function/src/scalars/expression/unary.rs @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common_query::error::{self, Result}; use datatypes::prelude::*; +use datatypes::vectors::Helper; use snafu::ResultExt; -use crate::error::{GetScalarVectorSnafu, Result}; use crate::scalars::expression::ctx::EvalContext; /// TODO: remove the allow_unused when it's used. @@ -28,7 +29,7 @@ pub fn scalar_unary_op( where F: Fn(Option>, &mut EvalContext) -> Option, { - let left = VectorHelper::check_get_scalar::(l).context(GetScalarVectorSnafu)?; + let left = Helper::check_get_scalar::(l).context(error::GetScalarVectorSnafu)?; let it = left.iter_data().map(|a| f(a, ctx)); let result = ::VectorType::from_owned_iterator(it); diff --git a/src/common/function/src/scalars/numpy.rs b/src/common/function/src/scalars/numpy.rs index 76140fb7de..ed8d9b6f30 100644 --- a/src/common/function/src/scalars/numpy.rs +++ b/src/common/function/src/scalars/numpy.rs @@ -13,7 +13,6 @@ // limitations under the License. mod clip; -#[allow(unused)] mod interp; use std::sync::Arc; diff --git a/src/common/function/src/scalars/numpy/clip.rs b/src/common/function/src/scalars/numpy/clip.rs index f743bf5ff5..be58614d70 100644 --- a/src/common/function/src/scalars/numpy/clip.rs +++ b/src/common/function/src/scalars/numpy/clip.rs @@ -15,14 +15,14 @@ use std::fmt; use std::sync::Arc; +use common_query::error::Result; use common_query::prelude::{Signature, Volatility}; use datatypes::data_type::{ConcreteDataType, DataType}; -use datatypes::prelude::{Scalar, VectorRef}; +use datatypes::prelude::*; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use paste::paste; -use crate::error::Result; use crate::scalars::expression::{scalar_binary_op, EvalContext}; use crate::scalars::function::{Function, FunctionContext}; @@ -38,8 +38,8 @@ macro_rules! define_eval { with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| { with_match_primitive_type_id!(columns[2].data_type().logical_type_id(), |$R| { // clip(a, min, max) is equals to min(max(a, min), max) - let col: VectorRef = Arc::new(scalar_binary_op::<$S, $T, $O, _>(&columns[0], &columns[1], scalar_max, &mut EvalContext::default())?); - let col = scalar_binary_op::<$O, $R, $O, _>(&col, &columns[2], scalar_min, &mut EvalContext::default())?; + let col: VectorRef = Arc::new(scalar_binary_op::<<$S as LogicalPrimitiveType>::Wrapper, <$T as LogicalPrimitiveType>::Wrapper, $O, _>(&columns[0], &columns[1], scalar_max, &mut EvalContext::default())?); + let col = scalar_binary_op::<$O, <$R as LogicalPrimitiveType>::Wrapper, $O, _>(&col, &columns[2], scalar_min, &mut EvalContext::default())?; Ok(Arc::new(col)) }, { unreachable!() diff --git a/src/common/function/src/scalars/numpy/interp.rs b/src/common/function/src/scalars/numpy/interp.rs index 67c247dc18..c4bb6e9811 100644 --- a/src/common/function/src/scalars/numpy/interp.rs +++ b/src/common/function/src/scalars/numpy/interp.rs @@ -14,41 +14,18 @@ use std::sync::Arc; -use datatypes::arrow::array::PrimitiveArray; +use common_query::error::{self, Result}; use datatypes::arrow::compute::cast; use datatypes::arrow::datatypes::DataType as ArrowDataType; use datatypes::data_type::DataType; use datatypes::prelude::ScalarVector; -use datatypes::type_id::LogicalTypeId; use datatypes::value::Value; -use datatypes::vectors::{Float64Vector, PrimitiveVector, Vector, VectorRef}; -use datatypes::{arrow, with_match_primitive_type_id}; -use snafu::{ensure, Snafu}; - -#[derive(Debug, Snafu)] -pub enum Error { - #[snafu(display( - "The length of the args is not enough, expect at least: {}, have: {}", - expect, - actual, - ))] - ArgsLenNotEnough { expect: usize, actual: usize }, - - #[snafu(display("The sample {} is empty", name))] - SampleEmpty { name: String }, - - #[snafu(display( - "The length of the len1: {} don't match the length of the len2: {}", - len1, - len2, - ))] - LenNotEquals { len1: usize, len2: usize }, -} - -pub type Result = std::result::Result; +use datatypes::vectors::{Float64Vector, Vector, VectorRef}; +use datatypes::with_match_primitive_type_id; +use snafu::{ensure, ResultExt}; /* search the biggest number that smaller than x in xp */ -fn linear_search_ascending_vector(x: Value, xp: &PrimitiveVector) -> usize { +fn linear_search_ascending_vector(x: Value, xp: &Float64Vector) -> usize { for i in 0..xp.len() { if x < xp.get(i) { return i - 1; @@ -58,7 +35,7 @@ fn linear_search_ascending_vector(x: Value, xp: &PrimitiveVector) -> usize } /* search the biggest number that smaller than x in xp */ -fn binary_search_ascending_vector(key: Value, xp: &PrimitiveVector) -> usize { +fn binary_search_ascending_vector(key: Value, xp: &Float64Vector) -> usize { let mut left = 0; let mut right = xp.len(); /* If len <= 4 use linear search. */ @@ -77,26 +54,33 @@ fn binary_search_ascending_vector(key: Value, xp: &PrimitiveVector) -> usiz left - 1 } -fn concrete_type_to_primitive_vector(arg: &VectorRef) -> Result> { +fn concrete_type_to_primitive_vector(arg: &VectorRef) -> Result { with_match_primitive_type_id!(arg.data_type().logical_type_id(), |$S| { let tmp = arg.to_arrow_array(); - let array = cast(&tmp, &DataType::Float64)?; - Ok(PrimitiveVector::new(array)) + let array = cast(&tmp, &ArrowDataType::Float64).context(error::TypeCastSnafu { + typ: ArrowDataType::Float64, + })?; + // Safety: array has been cast to Float64Array. + Ok(Float64Vector::try_from_arrow_array(array).unwrap()) },{ unreachable!() }) } /// https://github.com/numpy/numpy/blob/b101756ac02e390d605b2febcded30a1da50cc2c/numpy/core/src/multiarray/compiled_base.c#L491 +#[allow(unused)] pub fn interp(args: &[VectorRef]) -> Result { let mut left = None; let mut right = None; ensure!( args.len() >= 3, - ArgsLenNotEnoughSnafu { - expect: 3_usize, - actual: args.len() + error::InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not enough, expect at least: {}, have: {}", + 3, + args.len() + ), } ); @@ -108,9 +92,12 @@ pub fn interp(args: &[VectorRef]) -> Result { if args.len() > 3 { ensure!( args.len() == 5, - ArgsLenNotEnoughSnafu { - expect: 5_usize, - actual: args.len() + error::InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not enough, expect at least: {}, have: {}", + 5, + args.len() + ), } ); @@ -122,14 +109,32 @@ pub fn interp(args: &[VectorRef]) -> Result { .get_data(0); } - ensure!(x.len() != 0, SampleEmptySnafu { name: "x" }); - ensure!(xp.len() != 0, SampleEmptySnafu { name: "xp" }); - ensure!(fp.len() != 0, SampleEmptySnafu { name: "fp" }); + ensure!( + x.len() != 0, + error::InvalidFuncArgsSnafu { + err_msg: "The sample x is empty", + } + ); + ensure!( + xp.len() != 0, + error::InvalidFuncArgsSnafu { + err_msg: "The sample xp is empty", + } + ); + ensure!( + fp.len() != 0, + error::InvalidFuncArgsSnafu { + err_msg: "The sample fp is empty", + } + ); ensure!( xp.len() == fp.len(), - LenNotEqualsSnafu { - len1: xp.len(), - len2: fp.len(), + error::InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the len1: {} don't match the length of the len2: {}", + xp.len(), + fp.len() + ), } ); @@ -146,7 +151,7 @@ pub fn interp(args: &[VectorRef]) -> Result { let res; if xp.len() == 1 { - res = x + let datas = x .iter_data() .map(|x| { if Value::from(x) < xp.get(0) { @@ -157,7 +162,8 @@ pub fn interp(args: &[VectorRef]) -> Result { fp.get_data(0) } }) - .collect::(); + .collect::>(); + res = Float64Vector::from(datas); } else { let mut j = 0; /* only pre-calculate slopes if there are relatively few of them. */ @@ -184,7 +190,7 @@ pub fn interp(args: &[VectorRef]) -> Result { } slopes = Some(slopes_tmp); } - res = x + let datas = x .iter_data() .map(|x| match x { Some(xi) => { @@ -247,7 +253,8 @@ pub fn interp(args: &[VectorRef]) -> Result { } _ => None, }) - .collect::(); + .collect::>(); + res = Float64Vector::from(datas); } Ok(Arc::new(res) as _) } @@ -256,8 +263,7 @@ pub fn interp(args: &[VectorRef]) -> Result { mod tests { use std::sync::Arc; - use datatypes::prelude::ScalarVectorBuilder; - use datatypes::vectors::{Int32Vector, Int64Vector, PrimitiveVectorBuilder}; + use datatypes::vectors::{Int32Vector, Int64Vector}; use super::*; #[test] @@ -340,12 +346,8 @@ mod tests { assert!(matches!(vector.get(0), Value::Float64(v) if v==x[0] as f64)); // x=None output:Null - let input = [None, Some(0.0), Some(0.3)]; - let mut builder = PrimitiveVectorBuilder::with_capacity(input.len()); - for v in input { - builder.push(v); - } - let x = builder.finish(); + let input = vec![None, Some(0.0), Some(0.3)]; + let x = Float64Vector::from(input); let args: Vec = vec![ Arc::new(x), Arc::new(Int64Vector::from_vec(xp)), diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index 6df18e841d..a7d39c725c 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -76,8 +76,8 @@ pub enum Error { backtrace: Backtrace, }, - #[snafu(display("Invalid inputs: {}", err_msg))] - InvalidInputs { + #[snafu(display("Invalid input type: {}", err_msg))] + InvalidInputType { #[snafu(backtrace)] source: DataTypeError, err_msg: String, @@ -154,6 +154,12 @@ pub enum Error { #[snafu(backtrace)] source: DataTypeError, }, + + #[snafu(display("Invalid function args: {}", err_msg))] + InvalidFuncArgs { + err_msg: String, + backtrace: Backtrace, + }, } pub type Result = std::result::Result; @@ -172,7 +178,7 @@ impl ErrorExt for Error { | Error::GetScalarVector { .. } | Error::ArrowCompute { .. } => StatusCode::EngineExecuteQuery, - Error::InvalidInputs { source, .. } + Error::InvalidInputType { source, .. } | Error::IntoVector { source, .. } | Error::FromScalarValue { source } | Error::ConvertArrowSchema { source } @@ -182,9 +188,9 @@ impl ErrorExt for Error { | Error::GeneralDataFusion { .. } | Error::DataFusionExecutionPlan { .. } => StatusCode::Unexpected, - Error::UnsupportedInputDataType { .. } | Error::TypeCast { .. } => { - StatusCode::InvalidArguments - } + Error::UnsupportedInputDataType { .. } + | Error::TypeCast { .. } + | Error::InvalidFuncArgs { .. } => StatusCode::InvalidArguments, Error::ConvertDfRecordBatchStream { source, .. } => source.status_code(), Error::ExecutePhysicalPlan { source } => source.status_code(), diff --git a/src/datatypes/src/types/primitive_type.rs b/src/datatypes/src/types/primitive_type.rs index c005b89fee..ea752cf8de 100644 --- a/src/datatypes/src/types/primitive_type.rs +++ b/src/datatypes/src/types/primitive_type.rs @@ -109,8 +109,8 @@ pub trait LogicalPrimitiveType: 'static + Sized { pub struct OrdPrimitive(pub T); impl OrdPrimitive { - pub fn as_primitive(&self) -> T { - self.0 + pub fn as_primitive(&self) -> T::Native { + self.0.into_native() } } @@ -343,6 +343,7 @@ mod tests { heap: BinaryHeap::new(), }; foo.push($Type::default()); + assert_eq!($Type::default(), foo.heap.pop().unwrap().as_primitive()); }; } diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index 911f8600de..257cad2d87 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -464,7 +464,7 @@ impl ListValue { Ok(ScalarValue::List( vs, - Box::new(new_item_field(output_type.as_arrow_type())), + Box::new(new_item_field(output_type.item_type().as_arrow_type())), )) } }