diff --git a/src/common/function/src/scalars/aggregate/mean.rs b/src/common/function/src/scalars/aggregate/mean.rs index 0778655256..2393a58cd2 100644 --- a/src/common/function/src/scalars/aggregate/mean.rs +++ b/src/common/function/src/scalars/aggregate/mean.rs @@ -22,8 +22,7 @@ use common_query::error::{ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; -use datatypes::types::WrapperType; -use datatypes::vectors::{ConstantVector, Float64Vector, Helper, UInt64Vector}; +use datatypes::vectors::{ConstantVector, Float64Vector, UInt64Vector}; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use snafu::{ensure, OptionExt}; @@ -31,8 +30,7 @@ use snafu::{ensure, OptionExt}; #[derive(Debug, Default)] pub struct Mean where - T: WrapperType, - T::Native: AsPrimitive, + T: Primitive + AsPrimitive, { sum: f64, n: u64, @@ -41,8 +39,7 @@ where impl Mean where - T: WrapperType, - T::Native: AsPrimitive, + T: Primitive + AsPrimitive, { #[inline(always)] fn push(&mut self, value: T) { @@ -59,8 +56,8 @@ where impl Accumulator for Mean where - T: WrapperType, - T::Native: AsPrimitive, + T: Primitive + AsPrimitive, + for<'a> T: Scalar = T>, { fn state(&self) -> Result> { Ok(vec![self.sum.into(), self.n.into()]) @@ -76,10 +73,10 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; + unsafe { VectorHelper::static_cast(column.inner()) } } else { - unsafe { Helper::static_cast(column) } + unsafe { VectorHelper::static_cast(column) } }; (0..len).for_each(|_| { for v in column.iter_data().flatten() { diff --git a/src/common/function/src/scalars/aggregate/polyval.rs b/src/common/function/src/scalars/aggregate/polyval.rs index 32b37193b5..75a9d809f7 100644 --- a/src/common/function/src/scalars/aggregate/polyval.rs +++ b/src/common/function/src/scalars/aggregate/polyval.rs @@ -23,9 +23,9 @@ use common_query::error::{ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; -use datatypes::types::{LogicalPrimitiveType, WrapperType}; +use datatypes::types::PrimitiveType; use datatypes::value::ListValue; -use datatypes::vectors::{ConstantVector, Helper, Int64Vector, ListVector}; +use datatypes::vectors::{ConstantVector, Int64Vector, ListVector}; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use snafu::{ensure, OptionExt, ResultExt}; @@ -34,10 +34,8 @@ use snafu::{ensure, OptionExt, ResultExt}; #[derive(Debug, Default)] pub struct Polyval where - T: WrapperType, - T::Native: AsPrimitive, - PolyT: WrapperType, - PolyT::Native: std::ops::Mul, + T: Primitive + AsPrimitive, + PolyT: Primitive + std::ops::Mul, { values: Vec, // DataFusion casts constant in into i64 type. @@ -47,10 +45,8 @@ where impl Polyval where - T: WrapperType, - T::Native: AsPrimitive, - PolyT: WrapperType, - PolyT::Native: std::ops::Mul, + T: Primitive + AsPrimitive, + PolyT: Primitive + std::ops::Mul, { fn push(&mut self, value: T) { self.values.push(value); @@ -59,15 +55,11 @@ where impl Accumulator for Polyval where - T: WrapperType, - PolyT: WrapperType, - T::Native: AsPrimitive, - PolyT::Native: std::ops::Mul + std::iter::Sum, - // T: Primitive + AsPrimitive, - // PolyT: Primitive + std::ops::Mul + std::iter::Sum, - // for<'a> T: Scalar = T>, - // for<'a> PolyT: Scalar = PolyT>, - // i64: AsPrimitive, + T: Primitive + AsPrimitive, + PolyT: Primitive + std::ops::Mul + std::iter::Sum, + for<'a> T: Scalar = T>, + for<'a> PolyT: Scalar = PolyT>, + i64: AsPrimitive, { fn state(&self) -> Result> { let nums = self @@ -99,10 +91,10 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; + unsafe { VectorHelper::static_cast(column.inner()) } } else { - unsafe { Helper::static_cast(column) } + unsafe { VectorHelper::static_cast(column) } }; (0..len).for_each(|_| { for v in column.iter_data().flatten() { @@ -111,7 +103,7 @@ where }); let x = &values[1]; - let x = Helper::check_get_scalar::(x).context(error::InvalidInputsSnafu { + let x = VectorHelper::check_get_scalar::(x).context(error::InvalidInputsSnafu { err_msg: "expecting \"POLYVAL\" function's second argument to be a positive integer", })?; // `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0` @@ -181,7 +173,7 @@ where })?; for value in values.values_iter() { let value = value.context(FromScalarValueSnafu)?; - let column: &::VectorType = unsafe { Helper::static_cast(&value) }; + let column: &::VectorType = unsafe { VectorHelper::static_cast(&value) }; for v in column.iter_data().flatten() { self.push(v); } @@ -221,7 +213,7 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(Polyval::<$S,<<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default())) + Ok(Box::new(Polyval::<$S,<$S as Primitive>::LargestType>::default())) }, { let err_msg = format!( @@ -242,7 +234,7 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator { with_match_primitive_type_id!( input_type, |$S| { - Ok(<<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::build_data_type()) + Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().into()) }, { unreachable!() diff --git a/src/common/function/src/scalars/math/pow.rs b/src/common/function/src/scalars/math/pow.rs index fe28789e9b..fcbb877240 100644 --- a/src/common/function/src/scalars/math/pow.rs +++ b/src/common/function/src/scalars/math/pow.rs @@ -15,7 +15,6 @@ use std::fmt; use std::sync::Arc; -use common_query::error::Result; use common_query::prelude::{Signature, Volatility}; use datatypes::data_type::DataType; use datatypes::prelude::ConcreteDataType; @@ -24,6 +23,7 @@ use datatypes::with_match_primitive_type_id; use num::traits::Pow; use num_traits::AsPrimitive; +use crate::error::Result; use crate::scalars::expression::{scalar_binary_op, EvalContext}; use crate::scalars::function::{Function, FunctionContext}; diff --git a/src/common/function/src/scalars/numpy/interp.rs b/src/common/function/src/scalars/numpy/interp.rs index 67c247dc18..68981c2556 100644 --- a/src/common/function/src/scalars/numpy/interp.rs +++ b/src/common/function/src/scalars/numpy/interp.rs @@ -15,8 +15,8 @@ use std::sync::Arc; use datatypes::arrow::array::PrimitiveArray; -use datatypes::arrow::compute::cast; -use datatypes::arrow::datatypes::DataType as ArrowDataType; +use datatypes::arrow::compute::cast::primitive_to_primitive; +use datatypes::arrow::datatypes::DataType::Float64; use datatypes::data_type::DataType; use datatypes::prelude::ScalarVector; use datatypes::type_id::LogicalTypeId; @@ -80,7 +80,8 @@ fn binary_search_ascending_vector(key: Value, xp: &PrimitiveVector) -> usiz fn concrete_type_to_primitive_vector(arg: &VectorRef) -> Result> { with_match_primitive_type_id!(arg.data_type().logical_type_id(), |$S| { let tmp = arg.to_arrow_array(); - let array = cast(&tmp, &DataType::Float64)?; + let from = tmp.as_any().downcast_ref::>().expect("cast failed"); + let array = primitive_to_primitive(from, &Float64); Ok(PrimitiveVector::new(array)) },{ unreachable!()