feat: a new type for supplying Ord to Primitive (#255)

Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
LFC
2022-09-15 18:32:55 +08:00
committed by GitHub
parent dfa3012396
commit fb6153f7e0
7 changed files with 128 additions and 138 deletions

View File

@@ -9,9 +9,10 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, ListVector};
use datatypes::with_match_ordered_primitive_type_id;
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
@@ -36,17 +37,19 @@ use snafu::{ensure, OptionExt, ResultExt};
#[derive(Debug, Default)]
pub struct Median<T>
where
T: Primitive + Ord,
T: Primitive,
{
greater: BinaryHeap<Reverse<T>>,
not_greater: BinaryHeap<T>,
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>,
}
impl<T> Median<T>
where
T: Primitive + Ord,
T: Primitive,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
if self.not_greater.is_empty() {
self.not_greater.push(value);
return;
@@ -70,7 +73,7 @@ where
// to use them.
impl<T> Accumulator for Median<T>
where
T: Primitive + Ord,
T: Primitive,
for<'a> T: Scalar<RefType<'a> = T>,
{
// This function serializes our state to `ScalarValue`, which DataFusion uses to pass this
@@ -165,8 +168,8 @@ where
let greater = self.greater.peek().unwrap();
// the following three NumCast's `unwrap`s are safe because T is primitive
let not_greater_v: f64 = NumCast::from(not_greater).unwrap();
let greater_v: f64 = NumCast::from(greater.0).unwrap();
let not_greater_v: f64 = NumCast::from(not_greater.as_primitive()).unwrap();
let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap();
median.into()
};
@@ -182,7 +185,7 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_ordered_primitive_type_id!(
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Median::<$S>::default()))

View File

@@ -10,9 +10,10 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::{ListValue, OrderedFloat};
use datatypes::vectors::{ConstantVector, Float64Vector, ListVector};
use datatypes::with_match_ordered_primitive_type_id;
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
@@ -37,19 +38,21 @@ use snafu::{ensure, OptionExt, ResultExt};
#[derive(Debug, Default)]
pub struct Percentile<T>
where
T: Primitive + Ord,
T: Primitive,
{
greater: BinaryHeap<Reverse<T>>,
not_greater: BinaryHeap<T>,
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>,
n: u64,
p: Option<f64>,
}
impl<T> Percentile<T>
where
T: Primitive + Ord,
T: Primitive,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
self.n += 1;
if self.not_greater.is_empty() {
self.not_greater.push(value);
@@ -76,7 +79,7 @@ where
impl<T> Accumulator for Percentile<T>
where
T: Primitive + Ord,
T: Primitive,
for<'a> T: Scalar<RefType<'a> = T>,
{
fn state(&self) -> Result<Vec<Value>> {
@@ -212,7 +215,7 @@ where
if not_greater.is_none() {
return Ok(Value::Null);
}
let not_greater = *self.not_greater.peek().unwrap();
let not_greater = (*self.not_greater.peek().unwrap()).as_primitive();
let percentile = if self.greater.is_empty() {
NumCast::from(not_greater).unwrap()
} else {
@@ -224,7 +227,7 @@ where
};
let fract = (((self.n - 1) as f64) * p / 100_f64).fract();
let not_greater_v: f64 = NumCast::from(not_greater).unwrap();
let greater_v: f64 = NumCast::from(greater.0).unwrap();
let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
not_greater_v * (1.0 - fract) + greater_v * fract
};
Ok(Value::from(percentile))
@@ -239,7 +242,7 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_ordered_primitive_type_id!(
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Percentile::<$S>::default()))

View File

@@ -38,23 +38,6 @@ macro_rules! for_all_primitive_types {
};
}
#[macro_export]
macro_rules! for_all_ordered_primitive_types {
($macro:tt $(, $x:tt)*) => {
$macro! {
[$($x),*],
{ i8 },
{ i16 },
{ i32 },
{ i64 },
{ u8 },
{ u16 },
{ u32 },
{ u64 }
}
};
}
#[macro_export]
macro_rules! with_match_primitive_type_id {
($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{
@@ -81,27 +64,3 @@ macro_rules! with_match_primitive_type_id {
}
}};
}
#[macro_export]
macro_rules! with_match_ordered_primitive_type_id {
($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}
match $key_type {
LogicalTypeId::Int8 => __with_ty__! { i8 },
LogicalTypeId::Int16 => __with_ty__! { i16 },
LogicalTypeId::Int32 => __with_ty__! { i32 },
LogicalTypeId::Int64 => __with_ty__! { i64 },
LogicalTypeId::UInt8 => __with_ty__! { u8 },
LogicalTypeId::UInt16 => __with_ty__! { u16 },
LogicalTypeId::UInt32 => __with_ty__! { u32 },
LogicalTypeId::UInt64 => __with_ty__! { u64 },
_ => $nbody,
}
}};
}

View File

@@ -15,7 +15,7 @@ pub use date::DateType;
pub use datetime::DateTimeType;
pub use list_type::ListType;
pub use null_type::NullType;
pub use primitive_traits::Primitive;
pub use primitive_traits::{OrdPrimitive, Primitive};
pub use primitive_type::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, PrimitiveElement,
PrimitiveType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,

View File

@@ -1,3 +1,5 @@
use std::cmp::Ordering;
use arrow::compute::arithmetics::basic::NativeArithmetics;
use arrow::types::NativeType;
use num::NumCast;
@@ -41,3 +43,82 @@ impl_primitive!(i32, i64);
impl_primitive!(i64, i64);
impl_primitive!(f32, f64);
impl_primitive!(f64, f64);
/// A new type for [Primitive], complement the `Ord` feature for it. Wrapping not ordered
/// primitive types like `f32` and `f64` in `OrdPrimitive` can make them be used in places that
/// require `Ord`. For example, in `Median` or `Percentile` UDAFs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OrdPrimitive<T: Primitive>(pub T);
impl<T: Primitive> OrdPrimitive<T> {
pub fn as_primitive(&self) -> T {
self.0
}
}
impl<T: Primitive> Eq for OrdPrimitive<T> {}
impl<T: Primitive> PartialOrd for OrdPrimitive<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: Primitive> Ord for OrdPrimitive<T> {
fn cmp(&self, other: &Self) -> Ordering {
self.0.into().cmp(&other.0.into())
}
}
impl<T: Primitive> From<OrdPrimitive<T>> for Value {
fn from(p: OrdPrimitive<T>) -> Self {
p.0.into()
}
}
#[cfg(test)]
mod tests {
use std::collections::BinaryHeap;
use super::*;
#[test]
fn test_ord_primitive() {
struct Foo<T>
where
T: Primitive,
{
heap: BinaryHeap<OrdPrimitive<T>>,
}
impl<T> Foo<T>
where
T: Primitive,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
self.heap.push(value);
}
}
macro_rules! test {
($Type:ident) => {
let mut foo = Foo::<$Type> {
heap: BinaryHeap::new(),
};
foo.push($Type::default());
};
}
test!(u8);
test!(u16);
test!(u32);
test!(u64);
test!(i8);
test!(i16);
test!(i32);
test!(i64);
test!(f32);
test!(f64);
}
}

View File

@@ -9,7 +9,7 @@ use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::FieldExt;
use datafusion::field_util::SchemaExt;
use datatypes::for_all_ordered_primitive_types;
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::types::PrimitiveElement;
@@ -25,9 +25,6 @@ async fn test_percentile_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
test_percentile_failed::<f32>("f32_number", "numbers", engine.clone()).await?;
test_percentile_failed::<f64>("f64_number", "numbers", engine.clone()).await?;
macro_rules! test_percentile {
([], $( { $T:ty } ),*) => {
$(
@@ -36,7 +33,7 @@ async fn test_percentile_aggregator() -> Result<()> {
)*
}
}
for_all_ordered_primitive_types! { test_percentile }
for_all_primitive_types! { test_percentile }
Ok(())
}
@@ -114,24 +111,6 @@ async fn execute_percentile<'a>(
util::collect(recordbatch_stream).await
}
async fn test_percentile_failed<T>(
column_name: &str,
table_name: &str,
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement,
{
let result = execute_percentile(column_name, table_name, engine).await;
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.to_string().contains(&format!(
"Failed to create accumulator: \"PERCENTILE\" aggregate function not support data type {}",
T::type_name()
)));
Ok(())
}
fn create_correctness_engine() -> Arc<dyn QueryEngine> {
// create engine
let schema_provider = Arc::new(MemorySchemaProvider::new());

View File

@@ -14,11 +14,11 @@ use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::FieldExt;
use datafusion::field_util::SchemaExt;
use datafusion::logical_plan::LogicalPlanBuilder;
use datatypes::for_all_ordered_primitive_types;
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::types::PrimitiveElement;
use datatypes::vectors::{Float32Vector, Float64Vector, PrimitiveVector, UInt32Vector};
use datatypes::types::{OrdPrimitive, PrimitiveElement};
use datatypes::vectors::{PrimitiveVector, UInt32Vector};
use num::NumCast;
use query::error::Result;
use query::plan::LogicalPlan;
@@ -149,7 +149,7 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
let catalog_provider = Arc::new(MemoryCatalogProvider::new());
let catalog_list = Arc::new(MemoryCatalogList::default());
// create table with ordered primitives, and all columns' length are even
// create table with primitives, and all columns' length are even
let mut column_schemas = vec![];
let mut columns = vec![];
macro_rules! create_even_number_table {
@@ -161,13 +161,13 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
column_schemas.push(column_schema);
let numbers = (1..=100).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
let numbers = (1..=100).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
columns.push(column);
)*
}
}
for_all_ordered_primitive_types! { create_even_number_table }
for_all_primitive_types! { create_even_number_table }
let schema = Arc::new(Schema::new(column_schemas.clone()));
let recordbatch = RecordBatch::new(schema, columns).unwrap();
@@ -179,7 +179,7 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
)
.unwrap();
// create table with ordered primitives, and all columns' length are odd
// create table with primitives, and all columns' length are odd
let mut column_schemas = vec![];
let mut columns = vec![];
macro_rules! create_odd_number_table {
@@ -191,13 +191,13 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
column_schemas.push(column_schema);
let numbers = (1..=99).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
let numbers = (1..=99).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
columns.push(column);
)*
}
}
for_all_ordered_primitive_types! { create_odd_number_table }
for_all_primitive_types! { create_odd_number_table }
let schema = Arc::new(Schema::new(column_schemas.clone()));
let recordbatch = RecordBatch::new(schema, columns).unwrap();
@@ -206,24 +206,6 @@ fn create_query_engine() -> Arc<dyn QueryEngine> {
.register_table(odd_number_table.table_name().to_string(), odd_number_table)
.unwrap();
// create table with floating numbers
let column_schemas = vec![
ColumnSchema::new("f32_number", ConcreteDataType::float32_datatype(), true),
ColumnSchema::new("f64_number", ConcreteDataType::float64_datatype(), true),
];
let f32_numbers: VectorRef = Arc::new(Float32Vector::from_vec(vec![1.0f32, 2.0, 3.0]));
let f64_numbers: VectorRef = Arc::new(Float64Vector::from_vec(vec![1.0f64, 2.0, 3.0]));
let columns = vec![f32_numbers, f64_numbers];
let schema = Arc::new(Schema::new(column_schemas));
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let float_number_table = Arc::new(MemTable::new("float_numbers", recordbatch));
schema_provider
.register_table(
float_number_table.table_name().to_string(),
float_number_table,
)
.unwrap();
catalog_provider.register_schema(DEFAULT_SCHEMA_NAME.to_string(), schema_provider);
catalog_list.register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider);
@@ -235,7 +217,7 @@ async fn get_numbers_from_table<'s, T>(
column_name: &'s str,
table_name: &'s str,
engine: Arc<dyn QueryEngine>,
) -> Vec<T>
) -> Vec<OrdPrimitive<T>>
where
T: PrimitiveElement,
for<'a> T: Scalar<RefType<'a> = T>,
@@ -253,7 +235,11 @@ where
let columns = numbers[0].df_recordbatch.columns();
let column = VectorHelper::try_into_vector(&columns[0]).unwrap();
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&column) };
column.iter_data().flatten().collect::<Vec<T>>()
column
.iter_data()
.flatten()
.map(|x| OrdPrimitive::<T>(x))
.collect::<Vec<OrdPrimitive<T>>>()
}
#[tokio::test]
@@ -262,9 +248,6 @@ async fn test_median_aggregator() -> Result<()> {
let engine = create_query_engine();
test_median_failed::<f32>("f32_number", "float_numbers", engine.clone()).await?;
test_median_failed::<f64>("f64_number", "float_numbers", engine.clone()).await?;
macro_rules! test_median {
([], $( { $T:ty } ),*) => {
$(
@@ -276,7 +259,7 @@ async fn test_median_aggregator() -> Result<()> {
)*
}
}
for_all_ordered_primitive_types! { test_median }
for_all_primitive_types! { test_median }
Ok(())
}
@@ -286,7 +269,7 @@ async fn test_median_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + Ord,
T: PrimitiveElement,
for<'a> T: Scalar<RefType<'a> = T>,
{
let result = execute_median(column_name, table_name, engine.clone())
@@ -310,33 +293,15 @@ where
let expected_median: Value = if len % 2 == 1 {
numbers[len / 2]
} else {
let a: f64 = NumCast::from(numbers[len / 2 - 1]).unwrap();
let b: f64 = NumCast::from(numbers[len / 2]).unwrap();
NumCast::from(a / 2.0 + b / 2.0).unwrap()
let a: f64 = NumCast::from(numbers[len / 2 - 1].as_primitive()).unwrap();
let b: f64 = NumCast::from(numbers[len / 2].as_primitive()).unwrap();
OrdPrimitive::<T>(NumCast::from(a / 2.0 + b / 2.0).unwrap())
}
.into();
assert_eq!(expected_median, median);
Ok(())
}
async fn test_median_failed<T>(
column_name: &str,
table_name: &str,
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement,
{
let result = execute_median(column_name, table_name, engine).await;
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.to_string().contains(&format!(
"Failed to create accumulator: \"MEDIAN\" aggregate function not support data type {}",
T::type_name()
)));
Ok(())
}
async fn execute_median<'a>(
column_name: &'a str,
table_name: &'a str,