refactor!: simplify NativeType trait and remove percentile UDAF (#4758)

* refactor!: simplify NativeType trait and remove percentile UDAF

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* remove NativeType

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* recover a mis-deleted case

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2024-09-23 18:55:20 +08:00
committed by GitHub
parent 2feddca1cb
commit 5c64f0ce09
8 changed files with 5 additions and 583 deletions

View File

@@ -16,7 +16,6 @@ mod argmax;
mod argmin;
mod diff;
mod mean;
mod percentile;
mod polyval;
mod scipy_stats_norm_cdf;
mod scipy_stats_norm_pdf;
@@ -28,7 +27,6 @@ pub use argmin::ArgminAccumulatorCreator;
use common_query::logical_plan::AggregateFunctionCreatorRef;
pub use diff::DiffAccumulatorCreator;
pub use mean::MeanAccumulatorCreator;
pub use percentile::PercentileAccumulatorCreator;
pub use polyval::PolyvalAccumulatorCreator;
pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator;
pub use scipy_stats_norm_pdf::ScipyStatsNormPdfAccumulatorCreator;
@@ -91,7 +89,6 @@ impl AggregateFunctions {
register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator);
register_aggr_func!("argmax", 1, ArgmaxAccumulatorCreator);
register_aggr_func!("argmin", 1, ArgminAccumulatorCreator);
register_aggr_func!("percentile", 2, PercentileAccumulatorCreator);
register_aggr_func!("scipystatsnormcdf", 2, ScipyStatsNormCdfAccumulatorCreator);
register_aggr_func!("scipystatsnormpdf", 2, ScipyStatsNormPdfAccumulatorCreator);
}

View File

@@ -1,436 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
self, BadAccumulatorImplSnafu, CreateAccumulatorSnafu, DowncastVectorSnafu,
FromScalarValueSnafu, InvalidInputColSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::{ListValue, OrderedFloat};
use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector};
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
/// Percentile accumulator modelled on `numpy.percentile` with its default
/// (`linear`) interpolation method:
/// https://numpy.org/doc/stable/reference/generated/numpy.percentile.html?highlight=percentile#numpy.percentile
///
/// Given `n` values and a percentile `p` on the `0..=100` scale, the result is
/// the value `p / 100` of the way from the minimum to the maximum of the sorted
/// values: `p = 0` is the minimum, `p = 50` the median and `p = 100` the
/// maximum.
///
/// When the requested position falls between two data points, the result is
/// interpolated between them. With the 0-based rank
///
/// ```text
/// i + g = (n - 1) * p / 100
/// ```
///
/// `i` is the integer part (index of the lower neighbour) and `g` the
/// fractional part (weight of the upper neighbour). This is numpy's general
/// formula `i + g = q * (n - alpha - beta + 1) + alpha` (1-based, `q = p / 100`)
/// with the `linear` defaults `alpha = 1`, `beta = 1`.
/// NOTE(review): the general formula is quoted from the numpy documentation —
/// confirm against the numpy version you target.
///
/// The values are kept in two heaps so that the two neighbours are always at
/// the heap tops:
/// * `not_greater` is a max-heap holding the `i + 1` smallest values, so its top
///   is `sorted[i]`;
/// * `greater` is a min-heap (via `Reverse`) holding the remaining values, so
///   its top is `sorted[i + 1]`.
#[derive(Debug, Default)]
pub struct Percentile<T>
where
T: WrapperType,
{
/// Min-heap of the values above the percentile position.
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
/// Max-heap of the values at or below the percentile position.
not_greater: BinaryHeap<OrdPrimitive<T>>,
/// Number of values pushed so far.
n: u64,
/// Requested percentile on the `0..=100` scale; `None` until the first
/// non-empty batch (or merged state) supplies it.
p: Option<f64>,
}
impl<T> Percentile<T>
where
    T: WrapperType,
{
    /// Inserts one value and rebalances the two heaps.
    ///
    /// After the call `not_greater` is meant to hold `floor((n - 1) * p / 100) + 1`
    /// values, so that its top is the lower interpolation neighbour and the
    /// top of `greater` is the upper one. A missing `p` is treated as `0`.
    fn push(&mut self, value: T) {
        self.n += 1;
        let item = OrdPrimitive::<T>(value);

        // The very first value always seeds the lower heap.
        if self.not_greater.is_empty() {
            self.not_greater.push(item);
            return;
        }

        // Number of elements the lower heap should contain: floor(rank) + 1.
        let percent = self.p.unwrap_or(0.0_f64);
        let rank = ((self.n - 1) as f64) * percent / (100_f64);
        let target_len = (rank.floor() + 1.0) as usize;

        let belongs_low = item <= *self.not_greater.peek().unwrap();
        if belongs_low {
            self.not_greater.push(item);
            // Lower heap grew too large: hand its maximum over to the upper heap.
            if self.not_greater.len() > target_len {
                let moved = self.not_greater.pop().unwrap();
                self.greater.push(Reverse(moved));
            }
            return;
        }

        self.greater.push(Reverse(item));
        // Lower heap is too small: pull the minimum of the upper heap down.
        if self.not_greater.len() < target_len {
            let Reverse(moved) = self.greater.pop().unwrap();
            self.not_greater.push(moved);
        }
    }
}
impl<T> Accumulator for Percentile<T>
where
    T: WrapperType,
{
    /// Serializes the accumulator as two values: a list holding every buffered
    /// number (both heaps, in heap iteration order) and the percentile `p`.
    fn state(&self) -> Result<Vec<Value>> {
        let nums = self
            .greater
            .iter()
            .map(|x| &x.0)
            .chain(self.not_greater.iter())
            .map(|&n| n.into())
            .collect::<Vec<Value>>();
        Ok(vec![
            Value::List(ListValue::new(nums, T::LogicalType::build_data_type())),
            self.p.into(),
        ])
    }

    /// Consumes one input batch made of two equally long columns: the values
    /// (`values[0]`) and the percentile (`values[1]`).
    ///
    /// An empty slice or a zero-row batch is a no-op. Null values are skipped.
    ///
    /// # Errors
    ///
    /// * `InvalidInputState` if there are not exactly two columns or their
    ///   lengths differ;
    /// * `InvalidInputType` if the percentile column is not a float64 scalar
    ///   vector;
    /// * `InvalidInputCol` if the percentile is null, is not the same in every
    ///   row, or differs from the percentile seen in an earlier batch.
    fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        ensure!(values.len() == 2, InvalidInputStateSnafu);
        ensure!(values[0].len() == values[1].len(), InvalidInputStateSnafu);
        if values[0].len() == 0 {
            return Ok(());
        }

        // First column: the values. A constant vector stores its payload once,
        // so its inner vector is replayed `len` times below.
        let column = &values[0];
        let mut len = 1;
        // SAFETY (assumed, not checked here): the creator instantiates
        // `Percentile<T>` from the input column's logical type, so the column
        // is expected to be `T`'s vector type (or a constant vector wrapping
        // one) — NOTE(review): confirm against the caller.
        let column: &<T as Scalar>::VectorType = if column.is_const() {
            len = column.len();
            let column: &ConstantVector = unsafe { Helper::static_cast(column) };
            unsafe { Helper::static_cast(column.inner()) }
        } else {
            unsafe { Helper::static_cast(column) }
        };

        // Second column: the percentile, which must be a single non-null f64
        // repeated on every row.
        let x = &values[1];
        let x = Helper::check_get_scalar::<f64>(x).context(error::InvalidInputTypeSnafu {
            err_msg: "expecting \"PERCENTILE\" function's second argument to be float64",
        })?;
        // `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0`
        let first = x.get(0);
        ensure!(!first.is_null(), InvalidInputColSnafu);

        for i in 1..x.len() {
            ensure!(first == x.get(i), InvalidInputColSnafu);
        }

        let first = match first {
            Value::Float64(OrderedFloat(v)) => v,
            // unreachable because we have checked `first` is not null and is f64 above
            _ => unreachable!(),
        };
        // The percentile must stay the same across batches.
        if let Some(p) = self.p {
            ensure!(p == first, InvalidInputColSnafu);
        } else {
            self.p = Some(first);
        };

        (0..len).for_each(|_| {
            for v in column.iter_data().flatten() {
                self.push(v);
            }
        });
        Ok(())
    }

    /// Merges serialized states produced by [`Self::state`]: `states[0]` is a
    /// list vector of buffered values and `states[1]` a float64 vector holding
    /// the percentile.
    ///
    /// Only row 0 of the percentile vector is read. If it is null, nothing is
    /// merged; otherwise it replaces `self.p` and every non-null value of every
    /// list is pushed.
    ///
    /// # Errors
    ///
    /// * `BadAccumulatorImpl` if there are not exactly two state columns;
    /// * `DowncastVector` if a state column has an unexpected vector type;
    /// * `FromScalarValue` if a list entry cannot be converted to a vector.
    fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        ensure!(
            states.len() == 2,
            BadAccumulatorImplSnafu {
                err_msg: "expect 2 states in `merge_batch`"
            }
        );

        let p = &states[1];
        let p = p
            .as_any()
            .downcast_ref::<Float64Vector>()
            .with_context(|| DowncastVectorSnafu {
                err_msg: format!(
                    "expect float64vector, got vector type {}",
                    p.vector_type_name()
                ),
            })?;
        let p = p.get(0);
        if p.is_null() {
            return Ok(());
        }
        let p = match p {
            Value::Float64(OrderedFloat(p)) => p,
            // unreachable: a non-null value read from a `Float64Vector`
            _ => unreachable!(),
        };
        self.p = Some(p);

        let values = &states[0];
        let values = values
            .as_any()
            .downcast_ref::<ListVector>()
            .with_context(|| DowncastVectorSnafu {
                err_msg: format!(
                    "expect ListVector, got vector type {}",
                    values.vector_type_name()
                ),
            })?;
        for value in values.values_iter() {
            if let Some(value) = value.context(FromScalarValueSnafu)? {
                // SAFETY (assumed, not checked here): the list items are
                // expected to have been written by `state()` with item type
                // `T` — NOTE(review): confirm against the caller.
                let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(&value) };
                for v in column.iter_data().flatten() {
                    self.push(v);
                }
            }
        }
        Ok(())
    }

    /// Returns the percentile as a float64 value, or `Value::Null` when no
    /// value has been pushed.
    ///
    /// With an empty upper heap the result is the top of the lower heap;
    /// otherwise it is the linear interpolation between the two heap tops,
    /// weighted by the fractional part of `(n - 1) * p / 100`.
    ///
    /// # Panics
    ///
    /// Panics if the lower heap is empty while the upper heap is not, which
    /// would indicate a bug in the heap maintenance.
    fn evaluate(&self) -> Result<Value> {
        if self.not_greater.is_empty() {
            assert!(
                self.greater.is_empty(),
                "not expected in two-heap percentile algorithm, there must be a bug when implementing it"
            );
        }
        let not_greater = match self.not_greater.peek() {
            Some(top) => (*top).as_primitive(),
            None => return Ok(Value::Null),
        };
        let percentile = if self.greater.is_empty() {
            NumCast::from(not_greater).unwrap()
        } else {
            let greater = self.greater.peek().unwrap();
            let p = if let Some(p) = self.p {
                p
            } else {
                return Ok(Value::Null);
            };
            let fract = (((self.n - 1) as f64) * p / 100_f64).fract();
            let not_greater_v: f64 = NumCast::from(not_greater).unwrap();
            let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
            not_greater_v * (1.0 - fract) + greater_v * fract
        };
        Ok(Value::from(percentile))
    }
}
/// Factory for [`Percentile`] accumulators, registered as the `percentile`
/// aggregate function taking two arguments (values, percentile).
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct PercentileAccumulatorCreator {}

impl AggregateFunctionCreator for PercentileAccumulatorCreator {
    /// Returns a closure that builds a `Percentile` accumulator specialised
    /// for the logical type of the first argument, or fails with
    /// `CreateAccumulator` when that type is not a primitive one.
    fn creator(&self) -> AccumulatorCreatorFunction {
        let build: AccumulatorCreatorFunction = Arc::new(move |arg_types: &[ConcreteDataType]| {
            let value_type = &arg_types[0];
            with_match_primitive_type_id!(
                value_type.logical_type_id(),
                |$S| {
                    Ok(Box::new(Percentile::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
                },
                {
                    let err_msg = format!(
                        "\"PERCENTILE\" aggregate function not support data type {:?}",
                        value_type.logical_type_id(),
                    );
                    CreateAccumulatorSnafu { err_msg }.fail()?
                }
            )
        });
        build
    }

    /// The result is always float64; fails with `InvalidInputState` unless
    /// exactly two input types were recorded.
    fn output_type(&self) -> Result<ConcreteDataType> {
        let arity = self.input_types()?.len();
        ensure!(arity == 2, InvalidInputStateSnafu);
        Ok(ConcreteDataType::float64_datatype())
    }

    /// The intermediate state is `[list of the value type, float64]`, matching
    /// what `Percentile::state` emits; fails with `InvalidInputState` unless
    /// exactly two input types were recorded.
    fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
        let arg_types = self.input_types()?;
        ensure!(arg_types.len() == 2, InvalidInputStateSnafu);
        // The length check above guarantees a first element.
        let value_type = arg_types.into_iter().next().unwrap();
        let buffered_values = ConcreteDataType::list_datatype(value_type);
        let percentile = ConcreteDataType::float64_datatype();
        Ok(vec![buffered_values, percentile])
    }
}
#[cfg(test)]
mod test {
    use datatypes::vectors::{Float64Vector, Int32Vector};

    use super::*;

    /// Runs a fresh `Percentile::<i32>` over one batch in which every row
    /// carries the same percentile `p`, and returns the evaluated result.
    fn percentile_of(data: Vec<Option<i32>>, p: f64) -> Value {
        let ps = vec![Some(p); data.len()];
        let batch: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from(data)),
            Arc::new(Float64Vector::from(ps)),
        ];
        let mut acc = Percentile::<i32>::default();
        acc.update_batch(&batch).unwrap();
        acc.evaluate().unwrap()
    }

    #[test]
    fn test_update_batch() {
        // An empty batch must leave the accumulator untouched.
        let mut acc = Percentile::<i32>::default();
        acc.update_batch(&[]).unwrap();
        assert!(acc.not_greater.is_empty());
        assert!(acc.greater.is_empty());
        assert_eq!(Value::Null, acc.evaluate().unwrap());

        // A batch holding only a null value yields no result.
        assert_eq!(Value::Null, percentile_of(vec![None], 100.0_f64));

        // A constant vector contributes its single value once per row.
        let mut acc = Percentile::<i32>::default();
        let batch: Vec<VectorRef> = vec![
            Arc::new(ConstantVector::new(
                Arc::new(Int32Vector::from_vec(vec![4])),
                2,
            )),
            Arc::new(Float64Vector::from(vec![Some(100.0_f64), Some(100.0_f64)])),
        ];
        acc.update_batch(&batch).unwrap();
        assert_eq!(Value::from(4_f64), acc.evaluate().unwrap());

        // (input values, percentile, expected result)
        let cases: Vec<(Vec<Option<i32>>, f64, f64)> = vec![
            // one not-null value
            (vec![Some(42)], 100.0_f64, 42.0_f64),
            // batch without nulls
            (vec![Some(-1i32), Some(1), Some(2)], 100.0_f64, 2_f64),
            // batch with a null in the middle
            (vec![Some(-2i32), None, Some(3), Some(4)], 100.0_f64, 4_f64),
            // left border
            (vec![Some(-1i32), Some(1), Some(2)], 0.0_f64, -1.0_f64),
            // medium
            (vec![Some(-1i32), Some(1), Some(2)], 50.0_f64, 1.0_f64),
            // right border
            (vec![Some(-1i32), Some(1), Some(2)], 100.0_f64, 2.0_f64),
            // numpy: np.percentile(np.array([[10,7,4]]), 40) >> 6.400000000000
            (vec![Some(10i32), Some(7), Some(4)], 40.0_f64, 6.400000000_f64),
            // numpy: np.percentile(np.array([[10,7,4]]), 95) >> 9.7000000000000011
            (
                vec![Some(10i32), Some(7), Some(4)],
                95.0_f64,
                9.700_000_000_000_001_f64,
            ),
        ];
        for (data, p, expected) in cases {
            assert_eq!(Value::from(expected), percentile_of(data, p));
        }
    }
}

View File

@@ -48,7 +48,7 @@ pub use list_type::ListType;
pub use null_type::NullType;
pub use primitive_type::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, LogicalPrimitiveType,
NativeType, OrdPrimitive, UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType,
OrdPrimitive, UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType,
};
pub use string_type::StringType;
pub use time_type::{

View File

@@ -18,7 +18,6 @@ use std::fmt;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType as ArrowDataType};
use common_time::interval::IntervalUnit;
use common_time::{Date, DateTime};
use num::NumCast;
use serde::{Deserialize, Serialize};
use snafu::OptionExt;
@@ -31,27 +30,6 @@ use crate::types::{DateTimeType, DateType};
use crate::value::{Value, ValueRef};
use crate::vectors::{MutableVector, PrimitiveVector, PrimitiveVectorBuilder, Vector};
/// Marker trait for the native types arrow can store that also support
/// numeric casting through `NumCast`.
pub trait NativeType: ArrowNativeType + NumCast {}

/// Implements [`NativeType`] for each listed primitive type.
macro_rules! impl_native_type {
    ($($Type: ident),+ $(,)?) => {
        $(
            impl NativeType for $Type {}
        )+
    };
}

impl_native_type!(u8, u16, u32, u64, i8, i16, i32, i64, i128, f32, f64);
/// Represents the wrapper type that wraps a native type using the `newtype pattern`,
/// such as [Date](`common_time::Date`) is a wrapper type for the underlying native
/// type `i32`.
@@ -70,7 +48,7 @@ pub trait WrapperType:
/// Logical primitive type that this wrapper type belongs to.
type LogicalType: LogicalPrimitiveType<Wrapper = Self, Native = Self::Native>;
/// The underlying native type.
type Native: NativeType;
type Native: ArrowNativeType;
/// Convert native type into this wrapper type.
fn from_native(value: Self::Native) -> Self;
@@ -84,7 +62,7 @@ pub trait LogicalPrimitiveType: 'static + Sized {
/// Arrow primitive type of this logical type.
type ArrowPrimitive: ArrowPrimitiveType<Native = Self::Native>;
/// Native (physical) type of this logical type.
type Native: NativeType;
type Native: ArrowNativeType;
/// Wrapper type that the vector returns.
type Wrapper: WrapperType<LogicalType = Self, Native = Self::Native>
+ for<'a> Scalar<VectorType = PrimitiveVector<Self>, RefType<'a> = Self::Wrapper>
@@ -107,7 +85,7 @@ pub trait LogicalPrimitiveType: 'static + Sized {
/// A new type for [WrapperType], complement the `Ord` feature for it. Wrapping non ordered
/// primitive types like `f32` and `f64` in `OrdPrimitive` can make them be used in places that
/// require `Ord`. For example, in `Median` or `Percentile` UDAFs.
/// require `Ord`. For example, in `Median` UDAFs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OrdPrimitive<T: WrapperType>(pub T);

View File

@@ -25,7 +25,6 @@ mod argmax_test;
mod argmin_test;
mod mean_test;
mod my_sum_udaf_example;
mod percentile_test;
mod polyval_test;
mod query_engine_test;
mod scipy_stats_norm_cdf_test;

View File

@@ -1,97 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_recordbatch::RecordBatch;
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::Int32Vector;
use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use table::test_util::MemTable;
use super::new_query_engine_with_table;
use crate::error::Result;
use crate::tests::{exec_selection, function};
use crate::QueryEngine;
// Runs `PERCENTILE(<column>, 50.0)` against the `numbers` table once per
// primitive type and checks each result through `test_percentile_success`.
#[tokio::test]
async fn test_percentile_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
// Local macro expanded by `for_all_primitive_types!`: for every `{ T }` it
// receives, it queries the column named `"<type name of T>_number"`.
// NOTE(review): the leading `[]` matcher is presumably part of the calling
// convention of `for_all_primitive_types!` — confirm against its definition.
macro_rules! test_percentile {
([], $( { $T:ty } ),*) => {
$(
let column_name = format!("{}_number", std::any::type_name::<$T>());
test_percentile_success::<$T>(&column_name, "numbers", engine.clone()).await?;
)*
}
}
for_all_primitive_types! { test_percentile }
Ok(())
}
/// Checks the 88th percentile of the fixed `corr_numbers` table against a
/// known interpolated value.
#[tokio::test]
async fn test_percentile_correctness() -> Result<()> {
    let sql = "select PERCENTILE(corr_number,88.0) as percentile from corr_numbers".to_string();
    let batches = exec_selection(create_correctness_engine(), &sql).await;
    // The query yields a single row with a single column.
    let actual = batches[0].column(0).get(0);
    assert_eq!(actual, Value::from(9.280_000_000_000_001_f64));
    Ok(())
}
/// Compares the engine's `PERCENTILE(column, 50.0)` over `table_name` with the
/// 0.5 percentile that `inc_stats` computes from the same column's numbers.
async fn test_percentile_success<T>(
    column_name: &str,
    table_name: &str,
    engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
    T: WrapperType + AsPrimitive<f64>,
{
    let query = format!("select PERCENTILE({column_name},50.0) as percentile from {table_name}");
    let batches = exec_selection(engine.clone(), &query).await;
    let actual = function::get_value_from_batches("percentile", batches);

    // Reference result: feed the raw column values, converted to f64, to inc_stats.
    let samples = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
    let as_floats: Vec<f64> = samples.iter().map(|&n| n.as_()).collect();
    let reference: inc_stats::Percentiles<f64> = as_floats.iter().cloned().collect();
    let expected = reference.percentile(0.5).unwrap();

    assert_eq!(actual, expected.into());
    Ok(())
}
/// Builds a query engine over an in-memory table `corr_numbers` with a single
/// nullable int32 column `corr_number` holding `[3, 6, 8, 10]`.
fn create_correctness_engine() -> Arc<dyn QueryEngine> {
    let column_schemas = vec![ColumnSchema::new(
        "corr_number",
        ConcreteDataType::int32_datatype(),
        true,
    )];
    let schema = Arc::new(Schema::new(column_schemas));

    let data: VectorRef = Arc::new(Int32Vector::from_slice([3_i32, 6_i32, 8_i32, 10_i32]));
    let batch = RecordBatch::new(schema, vec![data]).unwrap();

    new_query_engine_with_table(MemTable::table("corr_numbers", batch))
}

View File

@@ -345,7 +345,7 @@ fn run_builtin_fn_testcases() {
match case.expect{
Ok(v) => {
error!("\nError:\n{err_res}");
panic!("Expect Ok: {v:?}, found Error");
panic!("Expect Ok: {v:?}, found Error in case {}", case.script);
},
Err(err) => {
if !err_res.contains(&err){

View File

@@ -989,25 +989,6 @@ argmin(p)"#,
},
script: r#"
from greptime import *
percentile(x, p)"#,
expect: Ok((
ty: Float64,
value: Float(-0.97)
))
),
TestCase(
input: {
"x": Var(
ty: Float64,
value: FloatVec([-1.0, 2.0, 3.0])
),
"p": Var(
ty: Float64,
value: FloatVec([0.5, 0.5, 0.5])
)
},
script: r#"
from greptime import *
scipy_stats_norm_cdf(x, p)"#,
expect: Ok((
ty: Float64,