fix: Fix compiler errors in argmax/rate/median/norm_cdf (#716)

* fix: Fix compiler errors in argmax/rate/median/norm_cdf

* chore: Address CR comments
This commit is contained in:
Yingwen
2022-12-07 15:28:27 +08:00
committed by GitHub
parent a562199455
commit a898f846d1
9 changed files with 99 additions and 127 deletions

View File

@@ -35,8 +35,6 @@ pub use polyval::PolyvalAccumulatorCreator;
pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator;
pub use scipy_stats_norm_pdf::ScipyStatsNormPdfAccumulatorCreator;
use crate::scalars::FunctionRegistry;
/// A function creates `AggregateFunctionCreator`.
/// "Aggregator" *is* AggregatorFunction. Since the later one is long, we named an short alias for it.
/// The two names might be used interchangeably.

View File

@@ -20,24 +20,22 @@ use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Resul
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::vectors::ConstantVector;
use datatypes::types::{LogicalPrimitiveType, WrapperType};
use datatypes::vectors::{ConstantVector, Helper};
use datatypes::with_match_primitive_type_id;
use snafu::ensure;
// https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
// return the index of the max value
#[derive(Debug, Default)]
pub struct Argmax<T>
where
T: Primitive + PartialOrd,
{
pub struct Argmax<T> {
max: Option<T>,
n: u64,
}
impl<T> Argmax<T>
where
T: Primitive + PartialOrd,
T: PartialOrd,
{
fn update(&mut self, value: T, index: u64) {
if let Some(Ordering::Less) = self.max.partial_cmp(&Some(value)) {
@@ -49,8 +47,7 @@ where
impl<T> Accumulator for Argmax<T>
where
T: Primitive + PartialOrd,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + PartialOrd,
{
fn state(&self) -> Result<Vec<Value>> {
match self.max {
@@ -66,10 +63,10 @@ where
let column = &values[0];
let column: &<T as Scalar>::VectorType = if column.is_const() {
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) }
let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { Helper::static_cast(column.inner()) }
} else {
unsafe { VectorHelper::static_cast(column) }
unsafe { Helper::static_cast(column) }
};
for (i, v) in column.iter_data().enumerate() {
if let Some(value) = v {
@@ -93,8 +90,8 @@ where
let max = &states[0];
let index = &states[1];
let max: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(max) };
let index: &<u64 as Scalar>::VectorType = unsafe { VectorHelper::static_cast(index) };
let max: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(max) };
let index: &<u64 as Scalar>::VectorType = unsafe { Helper::static_cast(index) };
index
.iter_data()
.flatten()
@@ -122,7 +119,7 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Argmax::<$S>::default()))
Ok(Box::new(Argmax::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
},
{
let err_msg = format!(
@@ -154,7 +151,7 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
#[cfg(test)]
mod test {
use datatypes::vectors::PrimitiveVector;
use datatypes::vectors::Int32Vector;
use super::*;
#[test]
@@ -166,21 +163,19 @@ mod test {
// test update one not-null value
let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))];
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
assert!(argmax.update_batch(&v).is_ok());
assert_eq!(Value::from(0_u64), argmax.evaluate().unwrap());
// test update one null value
let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Option::<i32>::None,
]))];
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
assert!(argmax.update_batch(&v).is_ok());
assert_eq!(Value::Null, argmax.evaluate().unwrap());
// test update no null-value batch
let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-1i32),
Some(1),
Some(3),
@@ -190,7 +185,7 @@ mod test {
// test update null-value batch
let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-2i32),
None,
Some(4),
@@ -201,7 +196,7 @@ mod test {
// test update with constant vector
let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])),
Arc::new(Int32Vector::from_vec(vec![4])),
10,
))];
assert!(argmax.update_batch(&v).is_ok());

View File

@@ -22,7 +22,6 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::{LogicalPrimitiveType, WrapperType};
use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, Helper, ListVector};
use datatypes::with_match_primitive_type_id;
@@ -32,20 +31,12 @@ use snafu::{ensure, OptionExt, ResultExt};
// https://numpy.org/doc/stable/reference/generated/numpy.diff.html
// I is the input type, O is the output type.
#[derive(Debug, Default)]
pub struct Diff<I, O>
where
I: WrapperType,
O: WrapperType,
{
pub struct Diff<I, O> {
values: Vec<I>,
_phantom: PhantomData<O>,
}
impl<I, O> Diff<I, O>
where
I: WrapperType,
O: WrapperType,
{
impl<I, O> Diff<I, O> {
fn push(&mut self, value: I) {
self.values.push(value);
}

View File

@@ -25,7 +25,7 @@ use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, ListVector};
use datatypes::vectors::{ConstantVector, Helper, ListVector};
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
@@ -51,7 +51,7 @@ use snafu::{ensure, OptionExt, ResultExt};
#[derive(Debug, Default)]
pub struct Median<T>
where
T: Primitive,
T: WrapperType,
{
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>,
@@ -59,7 +59,7 @@ where
impl<T> Median<T>
where
T: Primitive,
T: WrapperType,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
@@ -87,8 +87,7 @@ where
// to use them.
impl<T> Accumulator for Median<T>
where
T: Primitive,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + NumCast,
{
// This function serializes our state to `ScalarValue`, which DataFusion uses to pass this
// state between execution stages. Note that this can be arbitrary data.
@@ -105,7 +104,7 @@ where
.collect::<Vec<Value>>();
Ok(vec![Value::List(ListValue::new(
Some(Box::new(nums)),
T::default().into().data_type(),
T::LogicalType::build_data_type(),
))])
}
@@ -123,10 +122,10 @@ where
let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) }
let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { Helper::static_cast(column.inner()) }
} else {
unsafe { VectorHelper::static_cast(column) }
unsafe { Helper::static_cast(column) }
};
(0..len).for_each(|_| {
for v in column.iter_data().flatten() {
@@ -156,9 +155,10 @@ where
),
})?;
for state in states.values_iter() {
let state = state.context(FromScalarValueSnafu)?;
// merging state is simply accumulate stored numbers from others', so just call update
self.update_batch(&[state])?
if let Some(state) = state.context(FromScalarValueSnafu)? {
// merging state is simply accumulate stored numbers from others', so just call update
self.update_batch(&[state])?
}
}
Ok(())
}
@@ -202,7 +202,7 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator {
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Median::<$S>::default()))
Ok(Box::new(Median::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
},
{
let err_msg = format!(
@@ -230,7 +230,7 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator {
#[cfg(test)]
mod test {
use datatypes::vectors::PrimitiveVector;
use datatypes::vectors::Int32Vector;
use super::*;
#[test]
@@ -244,21 +244,19 @@ mod test {
// test update one not-null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))];
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(42), median.evaluate().unwrap());
// test update one null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Option::<i32>::None,
]))];
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Null, median.evaluate().unwrap());
// test update no null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-1i32),
Some(1),
Some(2),
@@ -268,7 +266,7 @@ mod test {
// test update null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-2i32),
None,
Some(3),
@@ -280,7 +278,7 @@ mod test {
// test update with constant vector
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])),
Arc::new(Int32Vector::from_vec(vec![4])),
10,
))];
assert!(median.update_batch(&v).is_ok());

View File

@@ -23,7 +23,7 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::value::{ListValue, OrderedFloat};
use datatypes::vectors::{ConstantVector, Float64Vector, ListVector};
use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector};
use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive;
use snafu::{ensure, OptionExt, ResultExt};
@@ -33,18 +33,12 @@ use statrs::statistics::Statistics;
// https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html
#[derive(Debug, Default)]
pub struct ScipyStatsNormCdf<T>
where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>,
{
pub struct ScipyStatsNormCdf<T> {
values: Vec<T>,
x: Option<f64>,
}
impl<T> ScipyStatsNormCdf<T>
where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>,
{
impl<T> ScipyStatsNormCdf<T> {
fn push(&mut self, value: T) {
self.values.push(value);
}
@@ -52,8 +46,8 @@ where
impl<T> Accumulator for ScipyStatsNormCdf<T>
where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + std::iter::Sum<T>,
T::Native: AsPrimitive<f64>,
{
fn state(&self) -> Result<Vec<Value>> {
let nums = self
@@ -64,7 +58,7 @@ where
Ok(vec![
Value::List(ListValue::new(
Some(Box::new(nums)),
T::default().into().data_type(),
T::LogicalType::build_data_type(),
)),
self.x.into(),
])
@@ -86,14 +80,14 @@ where
let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) }
let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { Helper::static_cast(column.inner()) }
} else {
unsafe { VectorHelper::static_cast(column) }
unsafe { Helper::static_cast(column) }
};
let x = &values[1];
let x = VectorHelper::check_get_scalar::<f64>(x).context(error::InvalidInputsSnafu {
let x = Helper::check_get_scalar::<f64>(x).context(error::InvalidInputsSnafu {
err_msg: "expecting \"SCIPYSTATSNORMCDF\" function's second argument to be a positive integer",
})?;
let first = x.get(0);
@@ -160,19 +154,19 @@ where
),
})?;
for value in values.values_iter() {
let value = value.context(FromScalarValueSnafu)?;
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&value) };
for v in column.iter_data().flatten() {
self.push(v);
if let Some(value) = value.context(FromScalarValueSnafu)? {
let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(&value) };
for v in column.iter_data().flatten() {
self.push(v);
}
}
}
Ok(())
}
fn evaluate(&self) -> Result<Value> {
let values = self.values.iter().map(|&v| v.as_()).collect::<Vec<_>>();
let mean = values.clone().mean();
let std_dev = values.std_dev();
let mean = self.values.iter().map(|v| v.into_native().as_()).mean();
let std_dev = self.values.iter().map(|v| v.into_native().as_()).std_dev();
if mean.is_nan() || std_dev.is_nan() {
Ok(Value::Null)
} else {
@@ -198,7 +192,7 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(ScipyStatsNormCdf::<$S>::default()))
Ok(Box::new(ScipyStatsNormCdf::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
},
{
let err_msg = format!(
@@ -230,7 +224,7 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
#[cfg(test)]
mod test {
use datatypes::vectors::PrimitiveVector;
use datatypes::vectors::{Float64Vector, Int32Vector};
use super::*;
#[test]
@@ -244,12 +238,8 @@ mod test {
// test update no null-value batch
let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::<i32>::default();
let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![
Some(-1i32),
Some(1),
Some(2),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])),
Arc::new(Float64Vector::from(vec![
Some(2.0_f64),
Some(2.0_f64),
Some(2.0_f64),
@@ -264,13 +254,8 @@ mod test {
// test update null-value batch
let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::<i32>::default();
let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![
Some(-2i32),
None,
Some(3),
Some(4),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])),
Arc::new(Float64Vector::from(vec![
Some(2.0_f64),
None,
Some(2.0_f64),

View File

@@ -14,10 +14,10 @@
use std::fmt;
use arrow::array::Array;
use common_query::error::{FromArrowArraySnafu, Result, TypeCastSnafu};
use common_query::error::{self, Result};
use common_query::prelude::{Signature, Volatility};
use datatypes::arrow;
use datatypes::arrow::compute::kernels::{arithmetic, cast};
use datatypes::arrow::datatypes::DataType;
use datatypes::prelude::*;
use datatypes::vectors::{Helper, VectorRef};
use snafu::ResultExt;
@@ -51,28 +51,21 @@ impl Function for RateFunction {
let val = &columns[0].to_arrow_array();
let val_0 = val.slice(0, val.len() - 1);
let val_1 = val.slice(1, val.len() - 1);
let dv = arrow::compute::arithmetics::sub(&*val_1, &*val_0);
let dv = arithmetic::subtract_dyn(&val_1, &val_0).context(error::ArrowComputeSnafu)?;
let ts = &columns[1].to_arrow_array();
let ts_0 = ts.slice(0, ts.len() - 1);
let ts_1 = ts.slice(1, ts.len() - 1);
let dt = arrow::compute::arithmetics::sub(&*ts_1, &*ts_0);
fn all_to_f64(array: &dyn Array) -> Result<Box<dyn Array>> {
Ok(arrow::compute::cast::cast(
array,
&arrow::datatypes::DataType::Float64,
arrow::compute::cast::CastOptions {
wrapped: true,
partial: true,
},
)
.context(TypeCastSnafu {
typ: arrow::datatypes::DataType::Float64,
})?)
}
let dv = all_to_f64(&*dv)?;
let dt = all_to_f64(&*dt)?;
let rate = arrow::compute::arithmetics::div(&*dv, &*dt);
let v = Helper::try_into_vector(&rate).context(FromArrowArraySnafu)?;
let dt = arithmetic::subtract_dyn(&ts_1, &ts_0).context(error::ArrowComputeSnafu)?;
let dv = cast::cast(&dv, &DataType::Float64).context(error::TypeCastSnafu {
typ: DataType::Float64,
})?;
let dt = cast::cast(&dt, &DataType::Float64).context(error::TypeCastSnafu {
typ: DataType::Float64,
})?;
let rate = arithmetic::divide_dyn(&dv, &dt).context(error::ArrowComputeSnafu)?;
let v = Helper::try_into_vector(&rate).context(error::FromArrowArraySnafu)?;
Ok(v)
}
}
@@ -81,9 +74,8 @@ impl Function for RateFunction {
mod tests {
use std::sync::Arc;
use arrow::array::Float64Array;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{Float32Vector, Int64Vector};
use datatypes::vectors::{Float32Vector, Float64Vector, Int64Vector};
use super::*;
#[test]
@@ -108,9 +100,7 @@ mod tests {
Arc::new(Int64Vector::from_vec(ts)),
];
let vector = rate.eval(FunctionContext::default(), &args).unwrap();
let arr = vector.to_arrow_array();
let expect = Arc::new(Float64Array::from_vec(vec![2.0, 3.0]));
let res = arrow::compute::comparison::eq(&*arr, &*expect);
res.iter().for_each(|x| assert!(matches!(x, Some(true))));
let expect: VectorRef = Arc::new(Float64Vector::from_vec(vec![2.0, 3.0]));
assert_eq!(expect, vector);
}
}

View File

@@ -127,10 +127,20 @@ pub enum Error {
source: BoxedError,
},
#[snafu(display("Fail to cast array to {:?}, source: {}", typ, source))]
#[snafu(display("Failed to cast array to {:?}, source: {}", typ, source))]
TypeCast {
source: ArrowError,
typ: arrow::datatypes::DataType,
backtrace: Backtrace,
},
#[snafu(display(
"Failed to perform compute operation on arrow arrays, source: {}",
source
))]
ArrowCompute {
source: ArrowError,
backtrace: Backtrace,
},
#[snafu(display("Query engine fail to cast value: {}", source))]
@@ -139,7 +149,7 @@ pub enum Error {
source: DataTypeError,
},
#[snafu(display("Fail to get scalar vector, {}", source))]
#[snafu(display("Failed to get scalar vector, {}", source))]
GetScalarVector {
#[snafu(backtrace)]
source: DataTypeError,
@@ -159,7 +169,8 @@ impl ErrorExt for Error {
| Error::InvalidInputCol { .. }
| Error::BadAccumulatorImpl { .. }
| Error::ToScalarValue { .. }
| Error::GetScalarVector { .. } => StatusCode::EngineExecuteQuery,
| Error::GetScalarVector { .. }
| Error::ArrowCompute { .. } => StatusCode::EngineExecuteQuery,
Error::InvalidInputs { source, .. }
| Error::IntoVector { source, .. }

View File

@@ -16,5 +16,6 @@ pub use crate::data_type::{ConcreteDataType, DataType, DataTypeRef};
pub use crate::macros::*;
pub use crate::scalars::{Scalar, ScalarRef, ScalarVector, ScalarVectorBuilder};
pub use crate::type_id::LogicalTypeId;
pub use crate::types::{LogicalPrimitiveType, WrapperType};
pub use crate::value::{Value, ValueRef};
pub use crate::vectors::{MutableVector, Validity, Vector, VectorRef};

View File

@@ -31,7 +31,10 @@ pub use list_type::ListType;
pub use null_type::NullType;
pub use primitive_type::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, LogicalPrimitiveType,
NativeType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType,
NativeType, OrdPrimitive, UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType,
};
pub use string_type::StringType;
pub use timestamp_type::*;
pub use timestamp_type::{
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, TimestampType,
};