From 6f3baf96b0860d3cfafa4223c4e437b5035072e7 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Wed, 7 Dec 2022 16:38:43 +0800 Subject: [PATCH] fix: fix compile error for mean/polyval/pow/interp ops (#717) * fix: fix compile error for mean/polyval/pow/interp ops Signed-off-by: Ruihang Xia * simplify type bounds Signed-off-by: Ruihang Xia Signed-off-by: Ruihang Xia --- .../function/src/scalars/aggregate/mean.rs | 21 +++++----- .../function/src/scalars/aggregate/polyval.rs | 39 ++++++++++--------- src/common/function/src/scalars/math/pow.rs | 2 +- .../function/src/scalars/numpy/interp.rs | 7 ++-- 4 files changed, 35 insertions(+), 34 deletions(-) diff --git a/src/common/function/src/scalars/aggregate/mean.rs b/src/common/function/src/scalars/aggregate/mean.rs index 2393a58cd2..f3dc723b41 100644 --- a/src/common/function/src/scalars/aggregate/mean.rs +++ b/src/common/function/src/scalars/aggregate/mean.rs @@ -22,16 +22,14 @@ use common_query::error::{ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; -use datatypes::vectors::{ConstantVector, Float64Vector, UInt64Vector}; +use datatypes::types::WrapperType; +use datatypes::vectors::{ConstantVector, Float64Vector, Helper, UInt64Vector}; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use snafu::{ensure, OptionExt}; #[derive(Debug, Default)] -pub struct Mean -where - T: Primitive + AsPrimitive, -{ +pub struct Mean { sum: f64, n: u64, _phantom: PhantomData, @@ -39,7 +37,8 @@ where impl Mean where - T: Primitive + AsPrimitive, + T: WrapperType, + T::Native: AsPrimitive, { #[inline(always)] fn push(&mut self, value: T) { @@ -56,8 +55,8 @@ where impl Accumulator for Mean where - T: Primitive + AsPrimitive, - for<'a> T: Scalar = T>, + T: WrapperType, + T::Native: AsPrimitive, { fn state(&self) -> Result> { Ok(vec![self.sum.into(), self.n.into()]) @@ -73,10 +72,10 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; (0..len).for_each(|_| { for v in column.iter_data().flatten() { diff --git a/src/common/function/src/scalars/aggregate/polyval.rs b/src/common/function/src/scalars/aggregate/polyval.rs index 75a9d809f7..409137212e 100644 --- a/src/common/function/src/scalars/aggregate/polyval.rs +++ b/src/common/function/src/scalars/aggregate/polyval.rs @@ -23,9 +23,9 @@ use common_query::error::{ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; -use datatypes::types::PrimitiveType; +use datatypes::types::{LogicalPrimitiveType, WrapperType}; use datatypes::value::ListValue; -use datatypes::vectors::{ConstantVector, Int64Vector, ListVector}; +use datatypes::vectors::{ConstantVector, Helper, Int64Vector, ListVector}; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use snafu::{ensure, OptionExt, ResultExt}; @@ -34,8 +34,10 @@ use snafu::{ensure, OptionExt, ResultExt}; #[derive(Debug, Default)] pub struct Polyval where - T: Primitive + AsPrimitive, - PolyT: Primitive + std::ops::Mul, + T: WrapperType, + T::Native: AsPrimitive, + PolyT: WrapperType, + PolyT::Native: std::ops::Mul, { values: Vec, // DataFusion casts constant in into i64 type. @@ -45,8 +47,10 @@ where impl Polyval where - T: Primitive + AsPrimitive, - PolyT: Primitive + std::ops::Mul, + T: WrapperType, + T::Native: AsPrimitive, + PolyT: WrapperType, + PolyT::Native: std::ops::Mul, { fn push(&mut self, value: T) { self.values.push(value); @@ -55,11 +59,10 @@ where impl Accumulator for Polyval where - T: Primitive + AsPrimitive, - PolyT: Primitive + std::ops::Mul + std::iter::Sum, - for<'a> T: Scalar = T>, - for<'a> PolyT: Scalar = PolyT>, - i64: AsPrimitive, + T: WrapperType, + PolyT: WrapperType, + T::Native: AsPrimitive, + PolyT::Native: std::ops::Mul + std::iter::Sum, { fn state(&self) -> Result> { let nums = self @@ -91,10 +94,10 @@ where let mut len = 1; let column: &::VectorType = if column.is_const() { len = column.len(); - let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; - unsafe { VectorHelper::static_cast(column.inner()) } + let column: &ConstantVector = unsafe { Helper::static_cast(column) }; + unsafe { Helper::static_cast(column.inner()) } } else { - unsafe { VectorHelper::static_cast(column) } + unsafe { Helper::static_cast(column) } }; (0..len).for_each(|_| { for v in column.iter_data().flatten() { @@ -103,7 +106,7 @@ where }); let x = &values[1]; - let x = VectorHelper::check_get_scalar::(x).context(error::InvalidInputsSnafu { + let x = Helper::check_get_scalar::(x).context(error::InvalidInputsSnafu { err_msg: "expecting \"POLYVAL\" function's second argument to be a positive integer", })?; // `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0` @@ -173,7 +176,7 @@ where })?; for value in values.values_iter() { let value = value.context(FromScalarValueSnafu)?; - let column: &::VectorType = unsafe { VectorHelper::static_cast(&value) }; + let column: &::VectorType = unsafe { Helper::static_cast(&value) }; for v in column.iter_data().flatten() { self.push(v); } @@ -213,7 +216,7 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(Polyval::<$S,<$S as Primitive>::LargestType>::default())) + Ok(Box::new(Polyval::<$S,<<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -234,7 +237,7 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator { with_match_primitive_type_id!( input_type, |$S| { - Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().into()) + Ok(<<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::build_data_type()) }, { unreachable!() diff --git a/src/common/function/src/scalars/math/pow.rs b/src/common/function/src/scalars/math/pow.rs index fcbb877240..fe28789e9b 100644 --- a/src/common/function/src/scalars/math/pow.rs +++ b/src/common/function/src/scalars/math/pow.rs @@ -15,6 +15,7 @@ use std::fmt; use std::sync::Arc; +use common_query::error::Result; use common_query::prelude::{Signature, Volatility}; use datatypes::data_type::DataType; use datatypes::prelude::ConcreteDataType; @@ -23,7 +24,6 @@ use datatypes::with_match_primitive_type_id; use num::traits::Pow; use num_traits::AsPrimitive; -use crate::error::Result; use crate::scalars::expression::{scalar_binary_op, EvalContext}; use crate::scalars::function::{Function, FunctionContext}; diff --git a/src/common/function/src/scalars/numpy/interp.rs b/src/common/function/src/scalars/numpy/interp.rs index 68981c2556..67c247dc18 100644 --- a/src/common/function/src/scalars/numpy/interp.rs +++ b/src/common/function/src/scalars/numpy/interp.rs @@ -15,8 +15,8 @@ use std::sync::Arc; use datatypes::arrow::array::PrimitiveArray; -use datatypes::arrow::compute::cast::primitive_to_primitive; -use datatypes::arrow::datatypes::DataType::Float64; +use datatypes::arrow::compute::cast; +use datatypes::arrow::datatypes::DataType as ArrowDataType; use datatypes::data_type::DataType; use datatypes::prelude::ScalarVector; use datatypes::type_id::LogicalTypeId; @@ -80,8 +80,7 @@ fn binary_search_ascending_vector(key: Value, xp: &PrimitiveVector) -> usiz fn concrete_type_to_primitive_vector(arg: &VectorRef) -> Result> { with_match_primitive_type_id!(arg.data_type().logical_type_id(), |$S| { let tmp = arg.to_arrow_array(); - let from = tmp.as_any().downcast_ref::>().expect("cast failed"); - let array = primitive_to_primitive(from, &Float64); + let array = cast(&tmp, &DataType::Float64)?; Ok(PrimitiveVector::new(array)) },{ unreachable!()