feat: a new type for supplying Ord to Primitive (#255)

Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
LFC
2022-09-15 18:32:55 +08:00
committed by GitHub
parent dfa3012396
commit fb6153f7e0
7 changed files with 128 additions and 138 deletions

View File

@@ -9,9 +9,10 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, ListVector};
use datatypes::with_match_ordered_primitive_type_id;
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
@@ -36,17 +37,19 @@ use snafu::{ensure, OptionExt, ResultExt};
#[derive(Debug, Default)]
pub struct Median<T>
where
T: Primitive + Ord,
T: Primitive,
{
greater: BinaryHeap<Reverse<T>>,
not_greater: BinaryHeap<T>,
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>,
}
impl<T> Median<T>
where
T: Primitive + Ord,
T: Primitive,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
if self.not_greater.is_empty() {
self.not_greater.push(value);
return;
@@ -70,7 +73,7 @@ where
// to use them.
impl<T> Accumulator for Median<T>
where
T: Primitive + Ord,
T: Primitive,
for<'a> T: Scalar<RefType<'a> = T>,
{
// This function serializes our state to `ScalarValue`, which DataFusion uses to pass this
@@ -165,8 +168,8 @@ where
let greater = self.greater.peek().unwrap();
// the following three NumCast's `unwrap`s are safe because T is primitive
let not_greater_v: f64 = NumCast::from(not_greater).unwrap();
let greater_v: f64 = NumCast::from(greater.0).unwrap();
let not_greater_v: f64 = NumCast::from(not_greater.as_primitive()).unwrap();
let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap();
median.into()
};
@@ -182,7 +185,7 @@ impl AggregateFunctionCreator for MedianAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_ordered_primitive_type_id!(
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Median::<$S>::default()))

View File

@@ -10,9 +10,10 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::{ListValue, OrderedFloat};
use datatypes::vectors::{ConstantVector, Float64Vector, ListVector};
use datatypes::with_match_ordered_primitive_type_id;
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
@@ -37,19 +38,21 @@ use snafu::{ensure, OptionExt, ResultExt};
#[derive(Debug, Default)]
pub struct Percentile<T>
where
T: Primitive + Ord,
T: Primitive,
{
greater: BinaryHeap<Reverse<T>>,
not_greater: BinaryHeap<T>,
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>,
n: u64,
p: Option<f64>,
}
impl<T> Percentile<T>
where
T: Primitive + Ord,
T: Primitive,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
self.n += 1;
if self.not_greater.is_empty() {
self.not_greater.push(value);
@@ -76,7 +79,7 @@ where
impl<T> Accumulator for Percentile<T>
where
T: Primitive + Ord,
T: Primitive,
for<'a> T: Scalar<RefType<'a> = T>,
{
fn state(&self) -> Result<Vec<Value>> {
@@ -212,7 +215,7 @@ where
if not_greater.is_none() {
return Ok(Value::Null);
}
let not_greater = *self.not_greater.peek().unwrap();
let not_greater = (*self.not_greater.peek().unwrap()).as_primitive();
let percentile = if self.greater.is_empty() {
NumCast::from(not_greater).unwrap()
} else {
@@ -224,7 +227,7 @@ where
};
let fract = (((self.n - 1) as f64) * p / 100_f64).fract();
let not_greater_v: f64 = NumCast::from(not_greater).unwrap();
let greater_v: f64 = NumCast::from(greater.0).unwrap();
let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
not_greater_v * (1.0 - fract) + greater_v * fract
};
Ok(Value::from(percentile))
@@ -239,7 +242,7 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_ordered_primitive_type_id!(
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Percentile::<$S>::default()))