diff --git a/src/common/function/src/scalars/aggregate/median.rs b/src/common/function/src/scalars/aggregate/median.rs index ef2e1bf3f2..7cd3601a5c 100644 --- a/src/common/function/src/scalars/aggregate/median.rs +++ b/src/common/function/src/scalars/aggregate/median.rs @@ -9,9 +9,10 @@ use common_query::error::{ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; +use datatypes::types::OrdPrimitive; use datatypes::value::ListValue; use datatypes::vectors::{ConstantVector, ListVector}; -use datatypes::with_match_ordered_primitive_type_id; +use datatypes::with_match_primitive_type_id; use num::NumCast; use snafu::{ensure, OptionExt, ResultExt}; @@ -36,17 +37,19 @@ use snafu::{ensure, OptionExt, ResultExt}; #[derive(Debug, Default)] pub struct Median where - T: Primitive + Ord, + T: Primitive, { - greater: BinaryHeap>, - not_greater: BinaryHeap, + greater: BinaryHeap>>, + not_greater: BinaryHeap>, } impl Median where - T: Primitive + Ord, + T: Primitive, { fn push(&mut self, value: T) { + let value = OrdPrimitive::(value); + if self.not_greater.is_empty() { self.not_greater.push(value); return; @@ -70,7 +73,7 @@ where // to use them. impl Accumulator for Median where - T: Primitive + Ord, + T: Primitive, for<'a> T: Scalar = T>, { // This function serializes our state to `ScalarValue`, which DataFusion uses to pass this @@ -165,8 +168,8 @@ where let greater = self.greater.peek().unwrap(); // the following three NumCast's `unwrap`s are safe because T is primitive - let not_greater_v: f64 = NumCast::from(not_greater).unwrap(); - let greater_v: f64 = NumCast::from(greater.0).unwrap(); + let not_greater_v: f64 = NumCast::from(not_greater.as_primitive()).unwrap(); + let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap(); let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap(); median.into() }; @@ -182,7 +185,7 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { let input_type = &types[0]; - with_match_ordered_primitive_type_id!( + with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { Ok(Box::new(Median::<$S>::default())) diff --git a/src/common/function/src/scalars/aggregate/percentile.rs b/src/common/function/src/scalars/aggregate/percentile.rs index 3b79a4bdf1..fbf497edcc 100644 --- a/src/common/function/src/scalars/aggregate/percentile.rs +++ b/src/common/function/src/scalars/aggregate/percentile.rs @@ -10,9 +10,10 @@ use common_query::error::{ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use datatypes::prelude::*; +use datatypes::types::OrdPrimitive; use datatypes::value::{ListValue, OrderedFloat}; use datatypes::vectors::{ConstantVector, Float64Vector, ListVector}; -use datatypes::with_match_ordered_primitive_type_id; +use datatypes::with_match_primitive_type_id; use num::NumCast; use snafu::{ensure, OptionExt, ResultExt}; @@ -37,19 +38,21 @@ use snafu::{ensure, OptionExt, ResultExt}; #[derive(Debug, Default)] pub struct Percentile where - T: Primitive + Ord, + T: Primitive, { - greater: BinaryHeap>, - not_greater: BinaryHeap, + greater: BinaryHeap>>, + not_greater: BinaryHeap>, n: u64, p: Option, } impl Percentile where - T: Primitive + Ord, + T: Primitive, { fn push(&mut self, value: T) { + let value = OrdPrimitive::(value); + self.n += 1; if self.not_greater.is_empty() { self.not_greater.push(value); @@ -76,7 +79,7 @@ where impl Accumulator for Percentile where - T: Primitive + Ord, + T: Primitive, for<'a> T: Scalar = T>, { fn state(&self) -> Result> { @@ -212,7 +215,7 @@ where if not_greater.is_none() { return Ok(Value::Null); } - let not_greater = *self.not_greater.peek().unwrap(); + let not_greater = (*self.not_greater.peek().unwrap()).as_primitive(); let percentile = if self.greater.is_empty() { NumCast::from(not_greater).unwrap() } else { @@ -224,7 +227,7 @@ where }; let fract = (((self.n - 1) as f64) * p / 100_f64).fract(); let not_greater_v: f64 = NumCast::from(not_greater).unwrap(); - let greater_v: f64 = NumCast::from(greater.0).unwrap(); + let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap(); not_greater_v * (1.0 - fract) + greater_v * fract }; Ok(Value::from(percentile)) @@ -239,7 +242,7 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator { fn creator(&self) -> AccumulatorCreatorFunction { let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { let input_type = &types[0]; - with_match_ordered_primitive_type_id!( + with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { Ok(Box::new(Percentile::<$S>::default())) diff --git a/src/datatypes/src/macros.rs b/src/datatypes/src/macros.rs index da385cd4ce..fb94195e38 100644 --- a/src/datatypes/src/macros.rs +++ b/src/datatypes/src/macros.rs @@ -38,23 +38,6 @@ macro_rules! for_all_primitive_types { }; } -#[macro_export] -macro_rules! for_all_ordered_primitive_types { - ($macro:tt $(, $x:tt)*) => { - $macro! { - [$($x),*], - { i8 }, - { i16 }, - { i32 }, - { i64 }, - { u8 }, - { u16 }, - { u32 }, - { u64 } - } - }; -} - #[macro_export] macro_rules! with_match_primitive_type_id { ($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{ @@ -81,27 +64,3 @@ macro_rules! with_match_primitive_type_id { } }}; } - -#[macro_export] -macro_rules! with_match_ordered_primitive_type_id { - ($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{ - macro_rules! __with_ty__ { - ( $_ $T:ident ) => { - $body - }; - } - - match $key_type { - LogicalTypeId::Int8 => __with_ty__! { i8 }, - LogicalTypeId::Int16 => __with_ty__! { i16 }, - LogicalTypeId::Int32 => __with_ty__! { i32 }, - LogicalTypeId::Int64 => __with_ty__! { i64 }, - LogicalTypeId::UInt8 => __with_ty__! { u8 }, - LogicalTypeId::UInt16 => __with_ty__! { u16 }, - LogicalTypeId::UInt32 => __with_ty__! { u32 }, - LogicalTypeId::UInt64 => __with_ty__! { u64 }, - - _ => $nbody, - } - }}; -} diff --git a/src/datatypes/src/types.rs b/src/datatypes/src/types.rs index a3bfda8aa5..3bb5bfdb93 100644 --- a/src/datatypes/src/types.rs +++ b/src/datatypes/src/types.rs @@ -15,7 +15,7 @@ pub use date::DateType; pub use datetime::DateTimeType; pub use list_type::ListType; pub use null_type::NullType; -pub use primitive_traits::Primitive; +pub use primitive_traits::{OrdPrimitive, Primitive}; pub use primitive_type::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, PrimitiveElement, PrimitiveType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, diff --git a/src/datatypes/src/types/primitive_traits.rs b/src/datatypes/src/types/primitive_traits.rs index be031fe1ca..941d857349 100644 --- a/src/datatypes/src/types/primitive_traits.rs +++ b/src/datatypes/src/types/primitive_traits.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use arrow::compute::arithmetics::basic::NativeArithmetics; use arrow::types::NativeType; use num::NumCast; @@ -41,3 +43,82 @@ impl_primitive!(i32, i64); impl_primitive!(i64, i64); impl_primitive!(f32, f64); impl_primitive!(f64, f64); + +/// A new type for [Primitive], complement the `Ord` feature for it. Wrapping not ordered +/// primitive types like `f32` and `f64` in `OrdPrimitive` can make them be used in places that +/// require `Ord`. For example, in `Median` or `Percentile` UDAFs. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct OrdPrimitive(pub T); + +impl OrdPrimitive { + pub fn as_primitive(&self) -> T { + self.0 + } +} + +impl Eq for OrdPrimitive {} + +impl PartialOrd for OrdPrimitive { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for OrdPrimitive { + fn cmp(&self, other: &Self) -> Ordering { + self.0.into().cmp(&other.0.into()) + } +} + +impl From> for Value { + fn from(p: OrdPrimitive) -> Self { + p.0.into() + } +} + +#[cfg(test)] +mod tests { + use std::collections::BinaryHeap; + + use super::*; + + #[test] + fn test_ord_primitive() { + struct Foo + where + T: Primitive, + { + heap: BinaryHeap>, + } + + impl Foo + where + T: Primitive, + { + fn push(&mut self, value: T) { + let value = OrdPrimitive::(value); + self.heap.push(value); + } + } + + macro_rules! test { + ($Type:ident) => { + let mut foo = Foo::<$Type> { + heap: BinaryHeap::new(), + }; + foo.push($Type::default()); + }; + } + + test!(u8); + test!(u16); + test!(u32); + test!(u64); + test!(i8); + test!(i16); + test!(i32); + test!(i64); + test!(f32); + test!(f64); + } +} diff --git a/src/query/tests/percentile_test.rs b/src/query/tests/percentile_test.rs index 550d9d9fe8..08bf7df854 100644 --- a/src/query/tests/percentile_test.rs +++ b/src/query/tests/percentile_test.rs @@ -9,7 +9,7 @@ use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; use datafusion::field_util::FieldExt; use datafusion::field_util::SchemaExt; -use datatypes::for_all_ordered_primitive_types; +use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::types::PrimitiveElement; @@ -25,9 +25,6 @@ async fn test_percentile_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); let engine = create_query_engine(); - test_percentile_failed::("f32_number", "numbers", engine.clone()).await?; - test_percentile_failed::("f64_number", "numbers", engine.clone()).await?; - macro_rules! test_percentile { ([], $( { $T:ty } ),*) => { $( @@ -36,7 +33,7 @@ async fn test_percentile_aggregator() -> Result<()> { )* } } - for_all_ordered_primitive_types! { test_percentile } + for_all_primitive_types! { test_percentile } Ok(()) } @@ -114,24 +111,6 @@ async fn execute_percentile<'a>( util::collect(recordbatch_stream).await } -async fn test_percentile_failed( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: PrimitiveElement, -{ - let result = execute_percentile(column_name, table_name, engine).await; - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains(&format!( - "Failed to create accumulator: \"PERCENTILE\" aggregate function not support data type {}", - T::type_name() - ))); - Ok(()) -} - fn create_correctness_engine() -> Arc { // create engine let schema_provider = Arc::new(MemorySchemaProvider::new()); diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs index 70703dd777..a1c9d79882 100644 --- a/src/query/tests/query_engine_test.rs +++ b/src/query/tests/query_engine_test.rs @@ -14,11 +14,11 @@ use common_recordbatch::{util, RecordBatch}; use datafusion::field_util::FieldExt; use datafusion::field_util::SchemaExt; use datafusion::logical_plan::LogicalPlanBuilder; -use datatypes::for_all_ordered_primitive_types; +use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::PrimitiveElement; -use datatypes::vectors::{Float32Vector, Float64Vector, PrimitiveVector, UInt32Vector}; +use datatypes::types::{OrdPrimitive, PrimitiveElement}; +use datatypes::vectors::{PrimitiveVector, UInt32Vector}; use num::NumCast; use query::error::Result; use query::plan::LogicalPlan; @@ -149,7 +149,7 @@ fn create_query_engine() -> Arc { let catalog_provider = Arc::new(MemoryCatalogProvider::new()); let catalog_list = Arc::new(MemoryCatalogList::default()); - // create table with ordered primitives, and all columns' length are even + // create table with primitives, and all columns' length are even let mut column_schemas = vec![]; let mut columns = vec![]; macro_rules! create_even_number_table { @@ -161,13 +161,13 @@ fn create_query_engine() -> Arc { let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true); column_schemas.push(column_schema); - let numbers = (1..=100).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::>(); + let numbers = (1..=100).map(|_| rng.gen::<$T>()).collect::>(); let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); columns.push(column); )* } } - for_all_ordered_primitive_types! { create_even_number_table } + for_all_primitive_types! { create_even_number_table } let schema = Arc::new(Schema::new(column_schemas.clone())); let recordbatch = RecordBatch::new(schema, columns).unwrap(); @@ -179,7 +179,7 @@ fn create_query_engine() -> Arc { ) .unwrap(); - // create table with ordered primitives, and all columns' length are odd + // create table with primitives, and all columns' length are odd let mut column_schemas = vec![]; let mut columns = vec![]; macro_rules! create_odd_number_table { @@ -191,13 +191,13 @@ fn create_query_engine() -> Arc { let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true); column_schemas.push(column_schema); - let numbers = (1..=99).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::>(); + let numbers = (1..=99).map(|_| rng.gen::<$T>()).collect::>(); let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); columns.push(column); )* } } - for_all_ordered_primitive_types! { create_odd_number_table } + for_all_primitive_types! { create_odd_number_table } let schema = Arc::new(Schema::new(column_schemas.clone())); let recordbatch = RecordBatch::new(schema, columns).unwrap(); @@ -206,24 +206,6 @@ fn create_query_engine() -> Arc { .register_table(odd_number_table.table_name().to_string(), odd_number_table) .unwrap(); - // create table with floating numbers - let column_schemas = vec![ - ColumnSchema::new("f32_number", ConcreteDataType::float32_datatype(), true), - ColumnSchema::new("f64_number", ConcreteDataType::float64_datatype(), true), - ]; - let f32_numbers: VectorRef = Arc::new(Float32Vector::from_vec(vec![1.0f32, 2.0, 3.0])); - let f64_numbers: VectorRef = Arc::new(Float64Vector::from_vec(vec![1.0f64, 2.0, 3.0])); - let columns = vec![f32_numbers, f64_numbers]; - let schema = Arc::new(Schema::new(column_schemas)); - let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let float_number_table = Arc::new(MemTable::new("float_numbers", recordbatch)); - schema_provider - .register_table( - float_number_table.table_name().to_string(), - float_number_table, - ) - .unwrap(); - catalog_provider.register_schema(DEFAULT_SCHEMA_NAME.to_string(), schema_provider); catalog_list.register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider); @@ -235,7 +217,7 @@ async fn get_numbers_from_table<'s, T>( column_name: &'s str, table_name: &'s str, engine: Arc, -) -> Vec +) -> Vec> where T: PrimitiveElement, for<'a> T: Scalar = T>, @@ -253,7 +235,11 @@ where let columns = numbers[0].df_recordbatch.columns(); let column = VectorHelper::try_into_vector(&columns[0]).unwrap(); let column: &::VectorType = unsafe { VectorHelper::static_cast(&column) }; - column.iter_data().flatten().collect::>() + column + .iter_data() + .flatten() + .map(|x| OrdPrimitive::(x)) + .collect::>>() } #[tokio::test] @@ -262,9 +248,6 @@ async fn test_median_aggregator() -> Result<()> { let engine = create_query_engine(); - test_median_failed::("f32_number", "float_numbers", engine.clone()).await?; - test_median_failed::("f64_number", "float_numbers", engine.clone()).await?; - macro_rules! test_median { ([], $( { $T:ty } ),*) => { $( @@ -276,7 +259,7 @@ async fn test_median_aggregator() -> Result<()> { )* } } - for_all_ordered_primitive_types! { test_median } + for_all_primitive_types! { test_median } Ok(()) } @@ -286,7 +269,7 @@ async fn test_median_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + Ord, + T: PrimitiveElement, for<'a> T: Scalar = T>, { let result = execute_median(column_name, table_name, engine.clone()) @@ -310,33 +293,15 @@ where let expected_median: Value = if len % 2 == 1 { numbers[len / 2] } else { - let a: f64 = NumCast::from(numbers[len / 2 - 1]).unwrap(); - let b: f64 = NumCast::from(numbers[len / 2]).unwrap(); - NumCast::from(a / 2.0 + b / 2.0).unwrap() + let a: f64 = NumCast::from(numbers[len / 2 - 1].as_primitive()).unwrap(); + let b: f64 = NumCast::from(numbers[len / 2].as_primitive()).unwrap(); + OrdPrimitive::(NumCast::from(a / 2.0 + b / 2.0).unwrap()) } .into(); assert_eq!(expected_median, median); Ok(()) } -async fn test_median_failed( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: PrimitiveElement, -{ - let result = execute_median(column_name, table_name, engine).await; - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains(&format!( - "Failed to create accumulator: \"MEDIAN\" aggregate function not support data type {}", - T::type_name() - ))); - Ok(()) -} - async fn execute_median<'a>( column_name: &'a str, table_name: &'a str,