feat: UDAF made generically (#91)

* feat: UDAF implementation backed by DataFusion.

Directly Transplant DataFusion's UDAF related structs, traits and functions, like `AggregateUDF`, `Accumulator` or `create_udaf` etc.

Implement median UDAF on top of it and used in unit testing.

Refs: #61

* feat: UDAF made generically

Refs: #61

* fix: cargo fmt

* fix: use prelude

* fix: uniform the name

* fix: move maybe commonly used functions together

* fix: make comments more clear

* fix: resolve conversations in CR

* fix: store input types in AccumulatorCreator, and use ScalarVector's iterator

* feat: introducing List value and List datatype

* refactor: use ArcSwap instead of Mutext

* refactor: shorten some namings

* refactor: move median UDAF out of tests

* refactor: rename

* feat: aggregate function registry

* fix: make `Value` satisfy ordering again

* fix: clippy warnings

* doc: add "how to write aggregate function"

* fix: address PR comments

* fix: trying to get rid of unwraps

Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
LFC
2022-07-25 10:35:36 +08:00
committed by GitHub
parent c126b480fd
commit 2b064265bf
37 changed files with 2000 additions and 38 deletions

9
Cargo.lock generated
View File

@@ -621,10 +621,12 @@ dependencies = [
name = "common-function"
version = "0.1.0"
dependencies = [
"arc-swap",
"arrow2",
"chrono-tz",
"common-error",
"common-query",
"datafusion-common",
"datatypes",
"num",
"num-traits",
@@ -644,6 +646,7 @@ dependencies = [
"datafusion-expr",
"datatypes",
"snafu",
"tokio",
]
[[package]]
@@ -1064,6 +1067,8 @@ dependencies = [
"common-error",
"datafusion-common",
"enum_dispatch",
"num",
"num-traits",
"ordered-float 3.0.0",
"paste",
"serde",
@@ -2696,6 +2701,7 @@ dependencies = [
name = "query"
version = "0.1.0"
dependencies = [
"arc-swap",
"arrow2",
"async-trait",
"common-error",
@@ -2709,6 +2715,9 @@ dependencies = [
"futures",
"futures-util",
"metrics",
"num",
"num-traits",
"rand 0.8.5",
"snafu",
"sql",
"table",

View File

@@ -0,0 +1,68 @@
Currently, our query engine is based on DataFusion, so all aggregate function is executed by DataFusion, through its UDAF interface. You can find DataFusion's UDAF example [here](https://github.com/apache/arrow-datafusion/blob/arrow2/datafusion-examples/examples/simple_udaf.rs). Basically, we provide the same way as DataFusion to write aggregate functions: both are centered in a struct called "Accumulator" to accumulates states along the way in aggregation.
However, DataFusion's UDAF implementation has a huge restriction, that it requires user to provide a concrete "Accumulator". Take `Median` aggregate function for example, to aggregate a `u32` datatype column, you have to write a `MedianU32`, and use `SELECT MEDIANU32(x)` in SQL. `MedianU32` cannot be used to aggregate a `i32` datatype column. Or, there's another way: you can use a special type that can hold all kinds of data (like our `Value` enum or Arrow's `ScalarValue`), and `match` all the way up to do aggregate calculations. It might work, though rather tedious. (But I think it's DataFusion's prefer way to write UDAF.)
So is there a way we can make an aggregate function that automatically match the input data's type? For example, a `Median` aggregator that can work on both `u32` column and `i32`? The answer is yes until we found a way to bypassing DataFusion's restriction, a restriction that DataFusion simply don't pass the input data's type when creating an Accumulator.
> There's an example in `my_sum_udaf_example.rs`, take that as quick start.
# 1. Impl `AggregateFunctionCreator` trait for your accumulator creator.
You must first define a struct that can store the input data's type. For example,
```Rust
struct MySumAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
```
Then impl `AggregateFunctionCreator` trait on it. The definition of the trait is:
```Rust
pub trait AggregateFunctionCreator: Send + Sync + Debug {
fn creator(&self) -> AccumulatorCreatorFunction;
fn input_types(&self) -> Vec<ConcreteDataType>;
fn set_input_types(&self, input_types: Vec<ConcreteDataType>);
fn output_type(&self) -> ConcreteDataType;
fn state_types(&self) -> Vec<ConcreteDataType>;
}
```
our query engine will call `set_input_types` the very first, so you can use input data's type in methods that return output type and state types.
The output type is aggregate function's output data's type. For example, `SUM` aggregate function's output type is `u64` for a `u32` datatype column. The state types are accumulator's internal states' types. Take `AVG` aggregate function on a `i32` column as example, it's state types are `i64` (for sum) and `u64` (for count).
The `creator` function is where you define how an accumulator (that will be used in DataFusion) is created. You define "how" to create the accumulator (instead of "what" to create), using the input data's type as arguments. With input datatype known, you can create accumulator generically.
# 2. Impl `Accumulator` trait for you accumulator.
The accumulator is where you store the aggregate calculation states and evaluate a result. You must impl `Accumulator` trait for it. The trait's definition is:
```Rust
pub trait Accumulator: Send + Sync + Debug {
fn state(&self) -> Result<Vec<Value>>;
fn update_batch(&mut self, values: &[VectorRef]) -> Result<()>;
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()>;
fn evaluate(&self) -> Result<Value>;
}
```
The DataFusion basically execute aggregate like this:
1. Partitioning all input data for aggregate. Create an accumulator for each part.
2. Call `update_batch` on each accumulator with partitioned data, to let you update your aggregate calculation.
3. Call `state` to get each accumulator's internal state, the medial calculation result.
4. Call `merge_batch` to merge all accumulator's internal state to one.
5. Execute `evalute` on the chosen one to get the final calculation result.
Once you know the meaning of each method, you can easily write your accumulator. You can refer to `Median` accumulator or `SUM` accumulator defined in file `my_sum_udaf_example.rs` for more details.
# 3. Register your aggregate function to our query engine.
You can call `register_aggregate_function` method in query engine to register your aggregate function. To do that, you have to new an instance of struct `AggregateFunctionMeta`. The struct has two fields, first is the name of your aggregate function's name. The function name is case-sensitive due to DataFusion's restriction. We strongly recommend using lowercase for your name. If you have to use uppercase name, wrap your aggregate function with quotation marks. For example, if you define an aggregate function named "my_aggr", you can use "`SELECT MY_AGGR(x)`"; if you define "my_AGGR", you have to use "`SELECT "my_AGGR"(x)`".
The second field is a function about how to create your accumulator creator that you defined in step 1 above. Create creator, that's a bit intertwined, but it is how we make DataFusion use a newly created aggregate function each time it executes a SQL, preventing the stored input types from affecting each other. The key detail can be starting looking at our `DfContextProviderAdapter` struct's `get_aggregate_meta` method.
# (Optional) 4. Make your aggregate function automatically registered.
If you've written a great aggregate function that want to let everyone use it, you can make it automatically registered to our query engine at start time. It's quick simple, just refer to the `AggregateFunctions::register` function in `common/function/src/scalars/aggregate/mod.rs`.

View File

@@ -11,9 +11,11 @@ features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc
[dependencies]
arc-swap = "1.0"
chrono-tz = "0.6"
common-error = { path = "../error" }
common-query = { path = "../query" }
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2" }
datatypes = { path = "../../datatypes" }
num = "0.4.0"
num-traits = "0.2.14"

View File

@@ -1,3 +1,4 @@
pub mod aggregate;
pub mod expression;
pub mod function;
pub mod function_registry;
@@ -7,5 +8,6 @@ pub mod numpy;
pub(crate) mod test;
pub mod udf;
pub use aggregate::MedianAccumulatorCreator;
pub use function::{Function, FunctionRef};
pub use function_registry::{FunctionRegistry, FUNCTION_REGISTRY};

View File

@@ -0,0 +1,300 @@
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_query::error::{
CreateAccumulatorSnafu, DowncastVectorSnafu, ExecuteFunctionSnafu, FromScalarValueSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datafusion_common::DataFusionError;
use datatypes::prelude::*;
use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, ListVector};
use datatypes::with_match_ordered_primitive_type_id;
use num::NumCast;
use snafu::{OptionExt, ResultExt};
// This median calculation algorithm's details can be found at
// https://leetcode.cn/problems/find-median-from-data-stream/
//
// Basically, it uses two heaps, a maximum heap and a minimum. The maximum heap stores numbers that
// are not greater than the median, and the minimum heap stores the greater. In a streaming of
// numbers, when a number is arrived, we adjust the heaps' tops, so that either one top is the
// median or both tops can be averaged to get the median.
//
// The time complexity to update the median is O(logn), O(1) to get the median; and the space
// complexity is O(n). (Ignore the costs for heap expansion.)
//
// From the point of algorithm, [quick select](https://en.wikipedia.org/wiki/Quickselect) might be
// better. But to use quick select here, we need a mutable self in the final calculation(`evaluate`)
// to swap stored numbers in the states vector. Though we can make our `evaluate` received
// `&mut self`, DataFusion calls our accumulator with `&self` (see `DfAccumulatorAdaptor`). That
// means we have to introduce some kinds of interior mutability, and the overhead is not neglectable.
//
// TODO(LFC): Use quick select to get median when we can modify DataFusion's code, and benchmark with two-heap algorithm.
#[derive(Debug, Default)]
pub struct Median<T>
where
T: Primitive + Ord,
{
greater: BinaryHeap<Reverse<T>>,
not_greater: BinaryHeap<T>,
}
impl<T> Median<T>
where
T: Primitive + Ord,
{
fn push(&mut self, value: T) {
if self.not_greater.is_empty() {
self.not_greater.push(value);
return;
}
// The `unwrap`s below are safe because there are `push`s before them.
if value <= *self.not_greater.peek().unwrap() {
self.not_greater.push(value);
if self.not_greater.len() > self.greater.len() + 1 {
self.greater.push(Reverse(self.not_greater.pop().unwrap()));
}
} else {
self.greater.push(Reverse(value));
if self.greater.len() > self.not_greater.len() {
self.not_greater.push(self.greater.pop().unwrap().0);
}
}
}
}
// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
// to use them.
impl<T> Accumulator for Median<T>
where
T: Primitive + Ord,
for<'a> T: Scalar<RefType<'a> = T>,
{
// This function serializes our state to `ScalarValue`, which DataFusion uses to pass this
// state between execution stages. Note that this can be arbitrary data.
//
// The `ScalarValue`s returned here will be passed in as argument `states: &[VectorRef]` to
// `merge_batch` function.
fn state(&self) -> Result<Vec<Value>> {
let nums = self
.greater
.iter()
.map(|x| &x.0)
.chain(self.not_greater.iter())
.map(|&n| n.into())
.collect::<Vec<Value>>();
Ok(vec![Value::List(ListValue::new(
Some(Box::new(nums)),
T::default().into().data_type(),
))])
}
// DataFusion calls this function to update the accumulator's state for a batch of inputs rows.
// It is expected this function to update the accumulator's state.
fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
};
// This is a unary accumulator, so only one column is provided.
let column = &values[0];
let column: &<T as Scalar>::VectorType = if column.is_const() {
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) }
} else {
unsafe { VectorHelper::static_cast(column) }
};
for v in column.iter_data().flatten() {
self.push(v);
}
Ok(())
}
// DataFusion executes accumulators in partitions. In some execution stage, DataFusion will
// merge states from other accumulators (returned by `state()` method).
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
};
// The states here are returned by the `state` method. Since we only returned a vector
// with one value in that method, `states[0]` is fine.
let states = &states[0];
let states = states
.as_any()
.downcast_ref::<ListVector>()
.with_context(|| DowncastVectorSnafu {
err_msg: format!(
"expect ListVector, got vector type {}",
states.vector_type_name()
),
})?;
for state in states.values_iter() {
let state = state.context(FromScalarValueSnafu)?;
// merging state is simply accumulate stored numbers from others', so just call update
self.update_batch(&[state])?
}
Ok(())
}
// DataFusion expects this function to return the final value of this aggregator.
fn evaluate(&self) -> Result<Value> {
if self.not_greater.is_empty() {
assert!(
self.greater.is_empty(),
"not expected in two-heap median algorithm, there must be a bug when implementing it"
);
return Ok(Value::Null);
}
// unwrap is safe because we checked not_greater heap's len above
let not_greater = *self.not_greater.peek().unwrap();
let median = if self.not_greater.len() > self.greater.len() {
not_greater.into()
} else {
// unwrap is safe because greater heap len >= not_greater heap len, which is > 0
let greater = self.greater.peek().unwrap();
// the following three NumCast's `unwrap`s are safe because T is primitive
let not_greater_v: f64 = NumCast::from(not_greater).unwrap();
let greater_v: f64 = NumCast::from(greater.0).unwrap();
let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap();
median.into()
};
Ok(median)
}
}
#[derive(Debug, Default)]
pub struct MedianAccumulatorCreator {
input_types: ArcSwapOption<Vec<ConcreteDataType>>,
}
impl AggregateFunctionCreator for MedianAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_ordered_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Median::<$S>::default()))
},
{
let err_msg = format!(
"\"MEDIAN\" aggregate function not support data type {:?}",
input_type.logical_type_id(),
);
CreateAccumulatorSnafu { err_msg }.fail()?
}
)
});
creator
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types.load();
if input_types.is_none() {
return Err(datafusion_internal_error()).context(ExecuteFunctionSnafu)?;
}
Ok(input_types.as_ref().unwrap().as_ref().clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()> {
let old = self.input_types.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
if old.len() != input_types.len() {
return Err(datafusion_internal_error()).context(ExecuteFunctionSnafu)?;
}
for (x, y) in old.iter().zip(input_types.iter()) {
if x != y {
return Err(datafusion_internal_error()).context(ExecuteFunctionSnafu)?;
}
}
}
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
if input_types.len() != 1 {
return Err(datafusion_internal_error()).context(ExecuteFunctionSnafu)?;
}
// unwrap is safe because we have checked input_types len must equals 1
Ok(input_types.into_iter().next().unwrap())
}
fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
Ok(vec![ConcreteDataType::list_datatype(self.output_type()?)])
}
}
fn datafusion_internal_error() -> DataFusionError {
DataFusionError::Internal(
"Illegal input_types status, check if DataFusion has changed its UDAF execution logic."
.to_string(),
)
}
#[cfg(test)]
mod test {
use datatypes::vectors::PrimitiveVector;
use super::*;
#[test]
fn test_update_batch() {
// test update empty batch, expect not updating anything
let mut median = Median::<i32>::default();
assert!(median.update_batch(&[]).is_ok());
assert!(median.not_greater.is_empty());
assert!(median.greater.is_empty());
assert_eq!(Value::Null, median.evaluate().unwrap());
// test update one not-null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(42), median.evaluate().unwrap());
// test update one null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Option::<i32>::None,
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Null, median.evaluate().unwrap());
// test update no null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Some(-1i32),
Some(1),
Some(2),
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(1), median.evaluate().unwrap());
// test update null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Some(-2i32),
None,
Some(3),
Some(4),
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(3), median.evaluate().unwrap());
// test update with constant vector
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])),
10,
))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(4), median.evaluate().unwrap());
}
}

View File

@@ -0,0 +1,50 @@
mod median;
use std::sync::Arc;
use common_query::logical_plan::AggregateFunctionCreatorRef;
pub use median::MedianAccumulatorCreator;
use crate::scalars::FunctionRegistry;
/// A function creates `AggregateFunctionCreator`.
/// "Aggregator" *is* AggregatorFunction. Since the later one is long, we named an short alias for it.
/// The two names might be used interchangeably.
type AggregatorCreatorFunction = Arc<dyn Fn() -> AggregateFunctionCreatorRef + Send + Sync>;
/// `AggregateFunctionMeta` dynamically creates AggregateFunctionCreator.
#[derive(Clone)]
pub struct AggregateFunctionMeta {
name: String,
creator: AggregatorCreatorFunction,
}
pub type AggregateFunctionMetaRef = Arc<AggregateFunctionMeta>;
impl AggregateFunctionMeta {
pub fn new(name: &str, creator: AggregatorCreatorFunction) -> Self {
Self {
name: name.to_string(),
creator,
}
}
pub fn name(&self) -> String {
self.name.to_string()
}
pub fn create(&self) -> AggregateFunctionCreatorRef {
(self.creator)()
}
}
pub(crate) struct AggregateFunctions;
impl AggregateFunctions {
pub fn register(registry: &FunctionRegistry) {
registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"median",
Arc::new(|| Arc::new(MedianAccumulatorCreator::default())),
)));
}
}

View File

@@ -5,6 +5,7 @@ use std::sync::RwLock;
use once_cell::sync::Lazy;
use crate::scalars::aggregate::{AggregateFunctionMetaRef, AggregateFunctions};
use crate::scalars::function::FunctionRef;
use crate::scalars::math::MathFunction;
use crate::scalars::numpy::NumpyFunction;
@@ -12,6 +13,7 @@ use crate::scalars::numpy::NumpyFunction;
#[derive(Default)]
pub struct FunctionRegistry {
functions: RwLock<HashMap<String, FunctionRef>>,
aggregate_functions: RwLock<HashMap<String, AggregateFunctionMetaRef>>,
}
impl FunctionRegistry {
@@ -22,6 +24,13 @@ impl FunctionRegistry {
.insert(func.name().to_string(), func);
}
pub fn register_aggregate_function(&self, func: AggregateFunctionMetaRef) {
self.aggregate_functions
.write()
.unwrap()
.insert(func.name(), func);
}
pub fn get_function(&self, name: &str) -> Option<FunctionRef> {
self.functions.read().unwrap().get(name).cloned()
}
@@ -29,6 +38,15 @@ impl FunctionRegistry {
pub fn functions(&self) -> Vec<FunctionRef> {
self.functions.read().unwrap().values().cloned().collect()
}
pub fn aggregate_functions(&self) -> Vec<AggregateFunctionMetaRef> {
self.aggregate_functions
.read()
.unwrap()
.values()
.cloned()
.collect()
}
}
pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
@@ -37,6 +55,8 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
MathFunction::register(&function_registry);
NumpyFunction::register(&function_registry);
AggregateFunctions::register(&function_registry);
Arc::new(function_registry)
});

View File

@@ -13,4 +13,7 @@ datafusion = { git = "https://github.com/apache/arrow-datafusion.git" , branch =
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2"}
datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2"}
datatypes = { path = "../../datatypes"}
snafu = { version = "0.7", features = ["backtraces"] }
snafu = { version = "0.7", features = ["backtraces"] }
[dev-dependencies]
tokio = { version = "1.0", features = ["full"] }

View File

@@ -28,6 +28,12 @@ pub enum InnerError {
source: DataTypeError,
data_type: ArrowDatatype,
},
#[snafu(display("Failed to create accumulator: {}", err_msg))]
CreateAccumulator { err_msg: String },
#[snafu(display("Failed to downcast vector: {}", err_msg))]
DowncastVector { err_msg: String },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -35,7 +41,9 @@ pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for InnerError {
fn status_code(&self) -> StatusCode {
match self {
InnerError::ExecuteFunction { .. } => StatusCode::EngineExecuteQuery,
InnerError::ExecuteFunction { .. }
| InnerError::CreateAccumulator { .. }
| InnerError::DowncastVector { .. } => StatusCode::EngineExecuteQuery,
InnerError::IntoVector { source, .. } => source.status_code(),
InnerError::FromScalarValue { source } => source.status_code(),
}

View File

@@ -1,10 +1,13 @@
use std::sync::Arc;
use datatypes::prelude::ConcreteDataType;
use arrow::datatypes::DataType as ArrowDataType;
use datafusion_expr::ReturnTypeFunction as DfReturnTypeFunction;
use datatypes::prelude::{ConcreteDataType, DataType};
use datatypes::vectors::VectorRef;
use snafu::ResultExt;
use crate::error::{ExecuteFunctionSnafu, Result};
use crate::logical_plan::Accumulator;
use crate::prelude::{ColumnarValue, ScalarValue};
/// Scalar function
@@ -22,6 +25,13 @@ pub type ScalarFunctionImplementation =
pub type ReturnTypeFunction =
Arc<dyn Fn(&[ConcreteDataType]) -> Result<Arc<ConcreteDataType>> + Send + Sync>;
/// Accumulator creator that will be used by DataFusion
pub type AccumulatorFunctionImpl = Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
/// Create Accumulator with the data type of input columns.
pub type AccumulatorCreatorFunction =
Arc<dyn Fn(&[ConcreteDataType]) -> Result<Box<dyn Accumulator>> + Sync + Send>;
/// This signature corresponds to which types an aggregator serializes
/// its state, given its return datatype.
pub type StateTypeFunction =
@@ -69,6 +79,25 @@ where
})
}
pub fn to_df_return_type(func: ReturnTypeFunction) -> DfReturnTypeFunction {
let df_func = move |data_types: &[ArrowDataType]| {
// DataFusion DataType -> ConcreteDataType
let concrete_data_types = data_types
.iter()
.map(ConcreteDataType::from_arrow_type)
.collect::<Vec<_>>();
// evaluate ConcreteDataType
let eval_result = (func)(&concrete_data_types);
// ConcreteDataType -> DataFusion DataType
eval_result
.map(|t| Arc::new(t.as_arrow_type()))
.map_err(|e| e.into())
};
Arc::new(df_func)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;

View File

@@ -0,0 +1,125 @@
//! Accumulator module contains the trait definition for aggregation function's accumulators.
use std::fmt::Debug;
use std::sync::Arc;
use arrow::array::ArrayRef;
use datafusion_common::Result as DfResult;
use datafusion_expr::Accumulator as DfAccumulator;
use datatypes::prelude::*;
use datatypes::vectors::Helper as VectorHelper;
use datatypes::vectors::VectorRef;
use snafu::ResultExt;
use crate::error::{Error, FromScalarValueSnafu, IntoVectorSnafu, Result};
use crate::prelude::*;
pub type AggregateFunctionCreatorRef = Arc<dyn AggregateFunctionCreator>;
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
/// generically accumulates values.
///
/// An accumulator knows how to:
/// * update its state from inputs via `update_batch`
/// * convert its internal state to a vector of scalar values
/// * update its state from multiple accumulators' states via `merge_batch`
/// * compute the final value from its internal state via `evaluate`
///
/// Modified from DataFusion.
pub trait Accumulator: Send + Sync + Debug {
/// Returns the state of the accumulator at the end of the accumulation.
// in the case of an average on which we track `sum` and `n`, this function should return a vector
// of two values, sum and n.
fn state(&self) -> Result<Vec<Value>>;
/// updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[VectorRef]) -> Result<()>;
/// updates the accumulator's state from a vector of states.
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()>;
/// returns its value based on its current state.
fn evaluate(&self) -> Result<Value>;
}
/// An `AggregateFunctionCreator` dynamically creates `Accumulator`.
/// DataFusion does not provide the input data's types when creating Accumulator, we have to stores
/// it somewhere else ourself. So an `AggregateFunctionCreator` often has a companion struct, that
/// can store the input data types, and knows the output and states types of an Accumulator.
/// That's how we create the Accumulator generically.
pub trait AggregateFunctionCreator: Send + Sync + Debug {
/// Create a function that can create a new accumulator with some input data type.
fn creator(&self) -> AccumulatorCreatorFunction;
/// Get the input data type of the Accumulator.
fn input_types(&self) -> Result<Vec<ConcreteDataType>>;
/// Store the input data type that is provided by DataFusion at runtime.
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> Result<()>;
/// Get the Accumulator's output data type.
fn output_type(&self) -> Result<ConcreteDataType>;
/// Get the Accumulator's state data types.
fn state_types(&self) -> Result<Vec<ConcreteDataType>>;
}
pub fn make_accumulator_function(
creator: Arc<dyn AggregateFunctionCreator>,
) -> AccumulatorFunctionImpl {
Arc::new(move || {
let input_types = creator.input_types()?;
let creator = creator.creator();
creator(&input_types)
})
}
pub fn make_return_function(creator: Arc<dyn AggregateFunctionCreator>) -> ReturnTypeFunction {
Arc::new(move |input_types| {
creator.set_input_types(input_types.to_vec())?;
let output_type = creator.output_type()?;
Ok(Arc::new(output_type))
})
}
pub fn make_state_function(creator: Arc<dyn AggregateFunctionCreator>) -> StateTypeFunction {
Arc::new(move |_| Ok(Arc::new(creator.state_types()?)))
}
/// A wrapper newtype for our Accumulator to DataFusion's Accumulator,
/// so to make our Accumulator able to be executed by DataFusion query engine.
#[derive(Debug)]
pub struct DfAccumulatorAdaptor(pub Box<dyn Accumulator>);
impl DfAccumulator for DfAccumulatorAdaptor {
fn state(&self) -> DfResult<Vec<ScalarValue>> {
let state = self.0.state()?;
Ok(state.into_iter().map(ScalarValue::from).collect())
}
fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
let vectors = VectorHelper::try_into_vectors(values)
.context(FromScalarValueSnafu)
.map_err(Error::from)?;
self.0.update_batch(&vectors).map_err(|e| e.into())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
let mut vectors = Vec::with_capacity(states.len());
for array in states.iter() {
vectors.push(
VectorHelper::try_into_vector(array)
.context(IntoVectorSnafu {
data_type: array.data_type().clone(),
})
.map_err(Error::from)?,
);
}
self.0.merge_batch(&vectors).map_err(|e| e.into())
}
fn evaluate(&self) -> DfResult<ScalarValue> {
Ok(ScalarValue::from(self.0.evaluate()?))
}
}

View File

@@ -1,13 +1,18 @@
mod accumulator;
mod expr;
mod udaf;
mod udf;
use std::sync::Arc;
use datatypes::prelude::ConcreteDataType;
pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef};
pub use self::expr::Expr;
pub use self::udaf::AggregateFunction;
pub use self::udf::ScalarUdf;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
use crate::logical_plan::accumulator::*;
use crate::signature::{Signature, Volatility};
/// Creates a new UDF with a specific signature and specific return type.
@@ -31,6 +36,22 @@ pub fn create_udf(
)
}
pub fn create_aggregate_function(
name: String,
creator: Arc<dyn AggregateFunctionCreator>,
) -> AggregateFunction {
let return_type = make_return_function(creator.clone());
let accumulator = make_accumulator_function(creator.clone());
let state_type = make_state_function(creator.clone());
AggregateFunction::new(
name,
Signature::any(1, Volatility::Immutable),
return_type,
accumulator,
state_type,
)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
@@ -40,13 +61,13 @@ mod tests {
use datafusion_expr::ColumnarValue as DfColumnarValue;
use datafusion_expr::ScalarUDF as DfScalarUDF;
use datafusion_expr::TypeSignature as DfTypeSignature;
use datatypes::prelude::ScalarVector;
use datatypes::prelude::*;
use datatypes::vectors::BooleanVector;
use datatypes::vectors::VectorRef;
use super::*;
use crate::error::Result;
use crate::function::make_scalar_function;
use crate::function::{make_scalar_function, AccumulatorCreatorFunction};
use crate::prelude::ScalarValue;
use crate::signature::TypeSignature;
@@ -129,4 +150,76 @@ mod tests {
_ => unreachable!(),
}
}
#[derive(Debug)]
struct DummyAccumulator;
impl Accumulator for DummyAccumulator {
fn state(&self) -> Result<Vec<Value>> {
Ok(vec![])
}
fn update_batch(&mut self, _values: &[VectorRef]) -> Result<()> {
Ok(())
}
fn merge_batch(&mut self, _states: &[VectorRef]) -> Result<()> {
Ok(())
}
fn evaluate(&self) -> Result<Value> {
Ok(Value::Int32(0))
}
}
#[derive(Debug)]
struct DummyAccumulatorCreator;
impl AggregateFunctionCreator for DummyAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
Arc::new(|_| Ok(Box::new(DummyAccumulator)))
}
fn input_types(&self) -> Result<Vec<ConcreteDataType>> {
Ok(vec![ConcreteDataType::float64_datatype()])
}
fn set_input_types(&self, _: Vec<ConcreteDataType>) -> Result<()> {
Ok(())
}
fn output_type(&self) -> Result<ConcreteDataType> {
Ok(self.input_types()?.into_iter().next().unwrap())
}
fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
Ok(vec![
ConcreteDataType::float64_datatype(),
ConcreteDataType::uint32_datatype(),
])
}
}
#[test]
fn test_create_udaf() {
let creator = DummyAccumulatorCreator;
let udaf = create_aggregate_function("dummy".to_string(), Arc::new(creator));
assert_eq!("dummy", udaf.name);
let signature = udaf.signature;
assert_eq!(TypeSignature::Any(1), signature.type_signature);
assert_eq!(Volatility::Immutable, signature.volatility);
assert_eq!(
Arc::new(ConcreteDataType::float64_datatype()),
(udaf.return_type)(&[ConcreteDataType::float64_datatype()]).unwrap()
);
assert_eq!(
Arc::new(vec![
ConcreteDataType::float64_datatype(),
ConcreteDataType::uint32_datatype(),
]),
(udaf.state_type)(&ConcreteDataType::float64_datatype()).unwrap()
);
}
}

View File

@@ -0,0 +1,104 @@
//! Udaf module contains functions and structs supporting user-defined aggregate functions.
//!
//! Modified from DataFusion.
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use arrow::datatypes::DataType as ArrowDataType;
use datafusion_expr::AccumulatorFunctionImplementation as DfAccumulatorFunctionImplementation;
use datafusion_expr::AggregateUDF as DfAggregateUdf;
use datafusion_expr::StateTypeFunction as DfStateTypeFunction;
use datatypes::prelude::*;
use crate::function::{
to_df_return_type, AccumulatorFunctionImpl, ReturnTypeFunction, StateTypeFunction,
};
use crate::logical_plan::accumulator::DfAccumulatorAdaptor;
use crate::signature::Signature;
/// Logical representation of a user-defined aggregate function (UDAF)
/// A UDAF is different from a UDF in that it is stateful across batches.
#[derive(Clone)]
pub struct AggregateFunction {
/// name
pub name: String,
/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
/// actual implementation
pub accumulator: AccumulatorFunctionImpl,
/// the accumulator's state's description as a function of the return type
pub state_type: StateTypeFunction,
}
impl Debug for AggregateFunction {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("AggregateUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}
impl PartialEq for AggregateFunction {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.signature == other.signature
}
}
impl AggregateFunction {
/// Create a new AggregateUDF
pub fn new(
name: String,
signature: Signature,
return_type: ReturnTypeFunction,
accumulator: AccumulatorFunctionImpl,
state_type: StateTypeFunction,
) -> Self {
Self {
name,
signature,
return_type,
accumulator,
state_type,
}
}
}
impl From<AggregateFunction> for DfAggregateUdf {
fn from(udaf: AggregateFunction) -> Self {
DfAggregateUdf::new(
&udaf.name,
&udaf.signature.into(),
&to_df_return_type(udaf.return_type),
&to_df_accumulator_func(udaf.accumulator),
&to_df_state_type(udaf.state_type),
)
}
}
fn to_df_accumulator_func(func: AccumulatorFunctionImpl) -> DfAccumulatorFunctionImplementation {
Arc::new(move || {
let acc = func()?;
Ok(Box::new(DfAccumulatorAdaptor(acc)))
})
}
fn to_df_state_type(func: StateTypeFunction) -> DfStateTypeFunction {
let df_func = move |data_type: &ArrowDataType| {
// DataFusion DataType -> ConcreteDataType
let concrete_data_type = ConcreteDataType::from_arrow_type(data_type);
// evaluate ConcreteDataType
let eval_result = (func)(&concrete_data_type);
// ConcreteDataType -> DataFusion DataType
eval_result
.map(|ts| Arc::new(ts.iter().map(|t| t.as_arrow_type()).collect()))
.map_err(|e| e.into())
};
Arc::new(df_func)
}

View File

@@ -5,15 +5,14 @@ use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;
use arrow::datatypes::DataType as ArrowDataType;
use datafusion_expr::{
ColumnarValue as DfColumnarValue, ReturnTypeFunction as DfReturnTypeFunction,
ColumnarValue as DfColumnarValue,
ScalarFunctionImplementation as DfScalarFunctionImplementation, ScalarUDF as DfScalarUDF,
};
use datatypes::prelude::{ConcreteDataType, DataType};
use crate::error::Result;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
use crate::prelude::to_df_return_type;
use crate::signature::Signature;
/// Logical representation of a UDF.
@@ -60,27 +59,12 @@ impl ScalarUdf {
DfScalarUDF::new(
&self.name,
&self.signature.into(),
&to_df_returntype(self.return_type),
&to_df_return_type(self.return_type),
&to_df_scalar_func(self.fun),
)
}
}
fn to_df_returntype(fun: ReturnTypeFunction) -> DfReturnTypeFunction {
Arc::new(move |data_types: &[ArrowDataType]| {
let concret_types = data_types
.iter()
.map(ConcreteDataType::from_arrow_type)
.collect::<Vec<ConcreteDataType>>();
let result = (fun)(&concret_types);
result
.map(|t| Arc::new(t.as_arrow_type()))
.map_err(|e| e.into())
})
}
fn to_df_scalar_func(fun: ScalarFunctionImplementation) -> DfScalarFunctionImplementation {
Arc::new(move |args: &[DfColumnarValue]| {
let args: Result<Vec<_>> = args.iter().map(TryFrom::try_from).collect();

View File

@@ -3,6 +3,7 @@ pub use datafusion_common::ScalarValue;
pub use crate::columnar_value::ColumnarValue;
pub use crate::function::*;
pub use crate::logical_plan::create_udf;
pub use crate::logical_plan::AggregateFunction;
pub use crate::logical_plan::Expr;
pub use crate::logical_plan::ScalarUdf;
pub use crate::signature::{Signature, TypeSignature, Volatility};

View File

@@ -15,6 +15,8 @@ datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , b
enum_dispatch = "0.3"
ordered-float = "3.0"
paste = "1.0"
num = "0.4"
num-traits = "0.2"
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0"
snafu = { version = "0.7", features = ["backtraces"] }

View File

@@ -7,7 +7,7 @@ use crate::error::{self, Error, Result};
use crate::type_id::LogicalTypeId;
use crate::types::{
BinaryType, BooleanType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
NullType, StringType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
ListType, NullType, StringType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use crate::value::Value;
@@ -32,6 +32,8 @@ pub enum ConcreteDataType {
// String types
Binary(BinaryType),
String(StringType),
List(ListType),
}
impl ConcreteDataType {
@@ -113,6 +115,9 @@ impl TryFrom<&ArrowDataType> for ConcreteDataType {
ArrowDataType::Float64 => Self::float64_datatype(),
ArrowDataType::Binary | ArrowDataType::LargeBinary => Self::binary_datatype(),
ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => Self::string_datatype(),
ArrowDataType::List(field) => Self::List(ListType::new(
ConcreteDataType::from_arrow_type(&field.data_type),
)),
_ => {
return error::UnsupportedArrowTypeSnafu {
arrow_type: dt.clone(),
@@ -144,6 +149,12 @@ impl_new_concrete_type_functions!(
Binary, String
);
impl ConcreteDataType {
pub fn list_datatype(inner_type: ConcreteDataType) -> ConcreteDataType {
ConcreteDataType::List(ListType::new(inner_type))
}
}
/// Data type abstraction.
#[enum_dispatch::enum_dispatch]
pub trait DataType: std::fmt::Debug + Send + Sync {
@@ -164,6 +175,8 @@ pub type DataTypeRef = Arc<dyn DataType>;
#[cfg(test)]
mod tests {
use arrow::datatypes::Field;
use super::*;
#[test]
@@ -242,5 +255,13 @@ mod tests {
ConcreteDataType::from_arrow_type(&ArrowDataType::LargeUtf8),
ConcreteDataType::String(_)
));
assert_eq!(
ConcreteDataType::from_arrow_type(&ArrowDataType::List(Box::new(Field::new(
"item",
ArrowDataType::Int32,
true
)))),
ConcreteDataType::List(ListType::new(ConcreteDataType::int32_datatype()))
);
}
}

View File

@@ -38,6 +38,23 @@ macro_rules! for_all_primitive_types{
};
}
#[macro_export]
macro_rules! for_all_ordered_primitive_types {
($macro:tt $(, $x:tt)*) => {
$macro! {
[$($x),*],
{ i8 },
{ i16 },
{ i32 },
{ i64 },
{ u8 },
{ u16 },
{ u32 },
{ u64 }
}
};
}
#[macro_export]
macro_rules! with_match_primitive_type_id {
($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{
@@ -63,3 +80,27 @@ macro_rules! with_match_primitive_type_id {
}
}};
}
#[macro_export]
macro_rules! with_match_ordered_primitive_type_id {
($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}
match $key_type {
LogicalTypeId::Int8 => __with_ty__! { i8 },
LogicalTypeId::Int16 => __with_ty__! { i16 },
LogicalTypeId::Int32 => __with_ty__! { i32 },
LogicalTypeId::Int64 => __with_ty__! { i64 },
LogicalTypeId::UInt8 => __with_ty__! { u8 },
LogicalTypeId::UInt16 => __with_ty__! { u16 },
LogicalTypeId::UInt32 => __with_ty__! { u32 },
LogicalTypeId::UInt64 => __with_ty__! { u64 },
_ => $nbody,
}
}};
}

View File

@@ -2,6 +2,7 @@ pub use crate::data_type::{ConcreteDataType, DataType, DataTypeRef};
pub use crate::macros::*;
pub use crate::scalars::{Scalar, ScalarRef, ScalarVector, ScalarVectorBuilder};
pub use crate::type_id::LogicalTypeId;
pub use crate::types::Primitive;
pub use crate::value::Value;
pub use crate::vectors::{
Helper as VectorHelper, MutableVector, Validity, Vector, VectorBuilder, VectorRef,

View File

@@ -29,6 +29,8 @@ pub enum LogicalTypeId {
/// Datetime representing the elapsed time since UNIX epoch (1970-01-01) in
/// seconds/milliseconds/microseconds/nanoseconds, determined by precision.
DateTime,
List,
}
impl LogicalTypeId {
@@ -50,7 +52,7 @@ impl LogicalTypeId {
LogicalTypeId::Float64 => ConcreteDataType::float64_datatype(),
LogicalTypeId::String => ConcreteDataType::string_datatype(),
LogicalTypeId::Binary => ConcreteDataType::binary_datatype(),
LogicalTypeId::Date | LogicalTypeId::DateTime => {
LogicalTypeId::Date | LogicalTypeId::DateTime | LogicalTypeId::List => {
unimplemented!("Data type for {:?} is unimplemented", self)
}
}

View File

@@ -1,5 +1,6 @@
mod binary_type;
mod boolean_type;
mod list_type;
mod null_type;
mod primitive_traits;
mod primitive_type;
@@ -7,6 +8,7 @@ mod string_type;
pub use binary_type::BinaryType;
pub use boolean_type::BooleanType;
pub use list_type::ListType;
pub use null_type::NullType;
pub use primitive_traits::Primitive;
pub use primitive_type::{

View File

@@ -0,0 +1,65 @@
use arrow::datatypes::{DataType as ArrowDataType, Field};
use crate::prelude::*;
use crate::value::ListValue;
/// Used to represent the List datatype.
#[derive(Debug, Clone, PartialEq)]
pub struct ListType {
/// The type of List's inner data.
inner: Box<ConcreteDataType>,
}
impl Default for ListType {
fn default() -> Self {
ListType::new(ConcreteDataType::null_datatype())
}
}
impl ListType {
pub fn new(datatype: ConcreteDataType) -> Self {
ListType {
inner: Box::new(datatype),
}
}
}
impl DataType for ListType {
fn name(&self) -> &str {
"List"
}
fn logical_type_id(&self) -> LogicalTypeId {
LogicalTypeId::List
}
fn default_value(&self) -> Value {
Value::List(ListValue::new(None, *self.inner.clone()))
}
fn as_arrow_type(&self) -> ArrowDataType {
let field = Box::new(Field::new("item", self.inner.as_arrow_type(), true));
ArrowDataType::List(field)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::value::ListValue;
#[test]
fn test_list_type() {
let t = ListType::new(ConcreteDataType::boolean_datatype());
assert_eq!("List", t.name());
assert_eq!(LogicalTypeId::List, t.logical_type_id());
assert_eq!(
Value::List(ListValue::new(None, ConcreteDataType::boolean_datatype())),
t.default_value()
);
assert_eq!(
ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Boolean, true))),
t.as_arrow_type()
);
}
}

View File

@@ -1,10 +1,22 @@
use arrow::compute::arithmetics::basic::NativeArithmetics;
use arrow::types::NativeType;
use num::NumCast;
use crate::prelude::Scalar;
use crate::value::Value;
/// Primitive type.
pub trait Primitive:
PartialOrd + Default + Clone + Copy + Into<Value> + NativeType + serde::Serialize
PartialOrd
+ Default
+ Clone
+ Copy
+ Into<Value>
+ NativeType
+ serde::Serialize
+ NativeArithmetics
+ NumCast
+ Scalar
{
/// Largest numeric type this primitive type can be cast to.
type LargestType: Primitive;

View File

@@ -1,8 +1,11 @@
use std::cmp::Ordering;
use common_base::bytes::{Bytes, StringBytes};
use datafusion_common::ScalarValue;
pub use ordered_float::OrderedFloat;
use serde::{Serialize, Serializer};
use crate::data_type::ConcreteDataType;
use crate::prelude::*;
pub type OrderedF32 = OrderedFloat<f32>;
pub type OrderedF64 = OrderedFloat<f64>;
@@ -36,6 +39,8 @@ pub enum Value {
// Date & Time types:
Date(i32),
DateTime(i64),
List(ListValue),
}
impl Value {
@@ -59,7 +64,7 @@ impl Value {
Value::Float64(_) => ConcreteDataType::float64_datatype(),
Value::String(_) => ConcreteDataType::string_datatype(),
Value::Binary(_) => ConcreteDataType::binary_datatype(),
Value::Date(_) | Value::DateTime(_) => {
Value::Date(_) | Value::DateTime(_) | Value::List(_) => {
unimplemented!("Unsupported data type of value {:?}", self)
}
}
@@ -149,10 +154,73 @@ impl Serialize for Value {
Value::Binary(bytes) => bytes.serialize(serializer),
Value::Date(v) => v.serialize(serializer),
Value::DateTime(v) => v.serialize(serializer),
Value::List(_) => unimplemented!(),
}
}
}
impl From<Value> for ScalarValue {
fn from(value: Value) -> Self {
match value {
Value::Boolean(v) => ScalarValue::Boolean(Some(v)),
Value::UInt8(v) => ScalarValue::UInt8(Some(v)),
Value::UInt16(v) => ScalarValue::UInt16(Some(v)),
Value::UInt32(v) => ScalarValue::UInt32(Some(v)),
Value::UInt64(v) => ScalarValue::UInt64(Some(v)),
Value::Int8(v) => ScalarValue::Int8(Some(v)),
Value::Int16(v) => ScalarValue::Int16(Some(v)),
Value::Int32(v) => ScalarValue::Int32(Some(v)),
Value::Int64(v) => ScalarValue::Int64(Some(v)),
Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
Value::String(v) => ScalarValue::LargeUtf8(Some(v.as_utf8().to_string())),
Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())),
Value::Date(v) => ScalarValue::Date32(Some(v)),
Value::DateTime(v) => ScalarValue::Date64(Some(v)),
Value::Null => ScalarValue::Boolean(None),
Value::List(v) => ScalarValue::List(
v.items
.map(|vs| Box::new(vs.into_iter().map(ScalarValue::from).collect())),
Box::new(v.datatype.as_arrow_type()),
),
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct ListValue {
/// List of nested Values (boxed to reduce size_of(Value))
#[allow(clippy::box_collection)]
items: Option<Box<Vec<Value>>>,
/// Inner values datatype, to distinguish empty lists of different datatypes.
/// Restricted by DataFusion, cannot use null datatype for empty list.
datatype: ConcreteDataType,
}
impl Eq for ListValue {}
impl ListValue {
pub fn new(items: Option<Box<Vec<Value>>>, datatype: ConcreteDataType) -> Self {
Self { items, datatype }
}
}
impl PartialOrd for ListValue {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for ListValue {
fn cmp(&self, other: &Self) -> Ordering {
assert_eq!(
self.datatype, other.datatype,
"Cannot compare different datatypes!"
);
self.items.cmp(&other.items)
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -297,4 +365,76 @@ mod tests {
let world: &[u8] = b"world";
assert_eq!(Value::Binary(Bytes::from(world)), Value::from(world));
}
#[test]
fn test_value_into_scalar_value() {
assert_eq!(
ScalarValue::Boolean(Some(true)),
Value::Boolean(true).into()
);
assert_eq!(
ScalarValue::Boolean(Some(true)),
Value::Boolean(true).into()
);
assert_eq!(
ScalarValue::UInt8(Some(u8::MIN + 1)),
Value::UInt8(u8::MIN + 1).into()
);
assert_eq!(
ScalarValue::UInt16(Some(u16::MIN + 2)),
Value::UInt16(u16::MIN + 2).into()
);
assert_eq!(
ScalarValue::UInt32(Some(u32::MIN + 3)),
Value::UInt32(u32::MIN + 3).into()
);
assert_eq!(
ScalarValue::UInt64(Some(u64::MIN + 4)),
Value::UInt64(u64::MIN + 4).into()
);
assert_eq!(
ScalarValue::Int8(Some(i8::MIN + 4)),
Value::Int8(i8::MIN + 4).into()
);
assert_eq!(
ScalarValue::Int16(Some(i16::MIN + 5)),
Value::Int16(i16::MIN + 5).into()
);
assert_eq!(
ScalarValue::Int32(Some(i32::MIN + 6)),
Value::Int32(i32::MIN + 6).into()
);
assert_eq!(
ScalarValue::Int64(Some(i64::MIN + 7)),
Value::Int64(i64::MIN + 7).into()
);
assert_eq!(
ScalarValue::Float32(Some(8.0f32)),
Value::Float32(OrderedFloat(8.0f32)).into()
);
assert_eq!(
ScalarValue::Float64(Some(9.0f64)),
Value::Float64(OrderedFloat(9.0f64)).into()
);
assert_eq!(
ScalarValue::LargeUtf8(Some("hello".to_string())),
Value::String(StringBytes::from("hello")).into()
);
assert_eq!(
ScalarValue::LargeBinary(Some("world".as_bytes().to_vec())),
Value::Binary(Bytes::from("world".as_bytes())).into()
);
assert_eq!(ScalarValue::Date32(Some(10i32)), Value::Date(10i32).into());
assert_eq!(
ScalarValue::Date64(Some(20i64)),
Value::DateTime(20i64).into()
);
assert_eq!(ScalarValue::Boolean(None), Value::Null.into());
}
}

View File

@@ -3,6 +3,7 @@ pub mod boolean;
mod builder;
pub mod constant;
mod helper;
mod list;
pub mod mutable;
pub mod null;
pub mod primitive;
@@ -18,6 +19,7 @@ pub use boolean::*;
pub use builder::VectorBuilder;
pub use constant::*;
pub use helper::Helper;
pub use list::*;
pub use mutable::MutableVector;
pub use null::*;
pub use primitive::*;

View File

@@ -78,6 +78,7 @@ impl VectorBuilder {
ConcreteDataType::Binary(_) => {
VectorBuilder::Binary(BinaryVectorBuilder::with_capacity(capacity))
}
_ => unimplemented!(),
}
}

View File

@@ -172,7 +172,35 @@ impl Helper {
ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => {
Arc::new(StringVector::try_from_arrow_array(array)?)
}
ArrowDataType::List(_) => Arc::new(ListVector::try_from_arrow_array(array)?),
_ => unimplemented!("Arrow array datatype: {:?}", array.as_ref().data_type()),
})
}
pub fn try_into_vectors(arrays: &[ArrayRef]) -> Result<Vec<VectorRef>> {
arrays.iter().map(Self::try_into_vector).collect()
}
}
#[cfg(test)]
mod tests {
use arrow::array::Int32Array;
use super::*;
#[test]
fn test_try_into_vectors() {
let arrays: Vec<ArrayRef> = vec![
Arc::new(Int32Array::from_vec(vec![1])),
Arc::new(Int32Array::from_vec(vec![2])),
Arc::new(Int32Array::from_vec(vec![3])),
];
let vectors = Helper::try_into_vectors(&arrays);
assert!(vectors.is_ok());
let vectors = vectors.unwrap();
vectors.iter().for_each(|v| assert_eq!(1, v.len()));
assert_eq!(Value::Int32(1), vectors[0].get(0));
assert_eq!(Value::Int32(2), vectors[1].get(0));
assert_eq!(Value::Int32(3), vectors[2].get(0));
}
}

View File

@@ -0,0 +1,286 @@
use std::any::Any;
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, ListArray};
use arrow::datatypes::DataType as ArrowDataType;
use serde_json::Value as JsonValue;
use snafu::prelude::*;
use crate::error::Result;
use crate::prelude::*;
use crate::serialize::Serializable;
use crate::types::ListType;
use crate::value::ListValue;
use crate::vectors::{impl_try_from_arrow_array_for_vector, impl_validity_for_vector};
type ArrowListArray = ListArray<i32>;
/// Vector of Lists, basically backed by Arrow's `ListArray`.
#[derive(Debug, Clone)]
pub struct ListVector {
array: ArrowListArray,
inner_data_type: ConcreteDataType,
}
impl ListVector {
pub fn values_iter(&self) -> Box<dyn Iterator<Item = Result<VectorRef>> + '_> {
Box::new(self.array.values_iter().map(VectorHelper::try_into_vector))
}
}
impl Vector for ListVector {
fn data_type(&self) -> ConcreteDataType {
ConcreteDataType::List(ListType::new(self.inner_data_type.clone()))
}
fn vector_type_name(&self) -> String {
"ListVector".to_string()
}
fn as_any(&self) -> &dyn Any {
self
}
fn len(&self) -> usize {
self.array.len()
}
fn to_arrow_array(&self) -> ArrayRef {
Arc::new(self.array.clone())
}
fn validity(&self) -> Validity {
impl_validity_for_vector!(self.array)
}
fn memory_size(&self) -> usize {
let offsets_bytes = self.array.offsets().len() * std::mem::size_of::<i64>();
let value_refs_bytes = self.array.values().len() * std::mem::size_of::<Arc<dyn Array>>();
offsets_bytes + value_refs_bytes
}
fn is_null(&self, row: usize) -> bool {
self.array.is_null(row)
}
fn slice(&self, offset: usize, length: usize) -> VectorRef {
Arc::new(ListVector::from(self.array.slice(offset, length)))
}
fn get(&self, index: usize) -> Value {
let array = &self.array.value(index);
let vector = VectorHelper::try_into_vector(array).unwrap_or_else(|_| {
panic!(
"arrow array with datatype {:?} cannot converted to our vector",
array.data_type()
)
});
let values = (0..vector.len())
.map(|i| vector.get(i))
.collect::<Vec<Value>>();
Value::List(ListValue::new(
Some(Box::new(values)),
self.inner_data_type.clone(),
))
}
fn replicate(&self, _: &[usize]) -> VectorRef {
// ListVector can be a scalar vector for implementing this `replicate` method. However,
// that requires a lot of efforts, starting from not using Arrow's ListArray.
// Refer to Databend's `ArrayColumn` for more details.
unimplemented!()
}
}
impl Serializable for ListVector {
fn serialize_to_json(&self) -> Result<Vec<JsonValue>> {
self.array
.iter()
.map(|v| match v {
None => Ok(JsonValue::Null),
Some(v) => VectorHelper::try_into_vector(v)
.and_then(|v| v.serialize_to_json())
.map(JsonValue::Array),
})
.collect()
}
}
impl From<ArrowListArray> for ListVector {
fn from(array: ArrowListArray) -> Self {
let inner_data_type = ConcreteDataType::from_arrow_type(match array.data_type() {
ArrowDataType::List(field) => &field.data_type,
_ => unreachable!(),
});
Self {
array,
inner_data_type,
}
}
}
impl_try_from_arrow_array_for_vector!(ArrowListArray, ListVector);
#[cfg(test)]
mod tests {
use arrow::array::{MutableListArray, MutablePrimitiveArray, TryExtend};
use super::*;
use crate::types::ListType;
#[test]
fn test_list_vector() {
let data = vec![
Some(vec![Some(1i32), Some(2), Some(3)]),
None,
Some(vec![Some(4), None, Some(6)]),
];
let mut arrow_array = MutableListArray::<i32, MutablePrimitiveArray<i32>>::new();
arrow_array.try_extend(data).unwrap();
let arrow_array: ArrowListArray = arrow_array.into();
let list_vector = ListVector {
array: arrow_array.clone(),
inner_data_type: ConcreteDataType::int32_datatype(),
};
assert_eq!(
ConcreteDataType::List(ListType::new(ConcreteDataType::int32_datatype())),
list_vector.data_type()
);
assert_eq!("ListVector", list_vector.vector_type_name());
assert_eq!(3, list_vector.len());
assert!(!list_vector.is_null(0));
assert!(list_vector.is_null(1));
assert!(!list_vector.is_null(2));
assert_eq!(
arrow_array,
list_vector
.to_arrow_array()
.as_any()
.downcast_ref::<ArrowListArray>()
.unwrap()
.clone()
);
assert_eq!(
Validity::Slots(arrow_array.validity().unwrap()),
list_vector.validity()
);
assert_eq!(
arrow_array.offsets().len() * std::mem::size_of::<i64>()
+ arrow_array.values().len() * std::mem::size_of::<Arc<dyn Array>>(),
list_vector.memory_size()
);
let slice = list_vector.slice(0, 2);
assert_eq!(
"ListArray[[1, 2, 3], None]",
format!("{:?}", slice.to_arrow_array())
);
assert_eq!(
Value::List(ListValue::new(
Some(Box::new(vec![
Value::Int32(1),
Value::Int32(2),
Value::Int32(3)
])),
ConcreteDataType::int32_datatype()
)),
list_vector.get(0)
);
assert_eq!(
Value::List(ListValue::new(
Some(Box::new(vec![])),
ConcreteDataType::int32_datatype()
)),
list_vector.get(1)
);
assert_eq!(
Value::List(ListValue::new(
Some(Box::new(vec![
Value::Int32(4),
Value::Null,
Value::Int32(6)
])),
ConcreteDataType::int32_datatype()
)),
list_vector.get(2)
);
}
#[test]
fn test_from_arrow_array() {
let data = vec![
Some(vec![Some(1u32), Some(2), Some(3)]),
None,
Some(vec![Some(4), None, Some(6)]),
];
let mut arrow_array = MutableListArray::<i32, MutablePrimitiveArray<u32>>::new();
arrow_array.try_extend(data).unwrap();
let arrow_array: ArrowListArray = arrow_array.into();
let array_ref: ArrayRef = Arc::new(arrow_array);
let list_vector = ListVector::try_from_arrow_array(array_ref).unwrap();
assert_eq!(
"ListVector { array: ListArray[[1, 2, 3], None, [4, None, 6]], inner_data_type: UInt32(UInt32) }",
format!("{:?}", list_vector)
);
}
#[test]
fn test_iter_list_vector_values() {
let data = vec![
Some(vec![Some(1i64), Some(2), Some(3)]),
None,
Some(vec![Some(4), None, Some(6)]),
];
let mut arrow_array = MutableListArray::<i32, MutablePrimitiveArray<i64>>::new();
arrow_array.try_extend(data).unwrap();
let arrow_array: ArrowListArray = arrow_array.into();
let list_vector = ListVector {
array: arrow_array,
inner_data_type: ConcreteDataType::int32_datatype(),
};
let mut iter = list_vector.values_iter();
assert_eq!(
"Int64[1, 2, 3]",
format!("{:?}", iter.next().unwrap().unwrap().to_arrow_array())
);
assert_eq!(
"Int64[]",
format!("{:?}", iter.next().unwrap().unwrap().to_arrow_array())
);
assert_eq!(
"Int64[4, None, 6]",
format!("{:?}", iter.next().unwrap().unwrap().to_arrow_array())
);
assert!(iter.next().is_none())
}
#[test]
fn test_serialize_to_json() {
let data = vec![
Some(vec![Some(1i64), Some(2), Some(3)]),
None,
Some(vec![Some(4), None, Some(6)]),
];
let mut arrow_array = MutableListArray::<i32, MutablePrimitiveArray<i64>>::new();
arrow_array.try_extend(data).unwrap();
let arrow_array: ArrowListArray = arrow_array.into();
let list_vector = ListVector {
array: arrow_array,
inner_data_type: ConcreteDataType::int32_datatype(),
};
assert_eq!(
"Ok([Array([Number(1), Number(2), Number(3)]), Null, Array([Number(4), Null, Number(6)])])",
format!("{:?}", list_vector.serialize_to_json())
);
}
}

View File

@@ -25,4 +25,4 @@ tempdir = "0.3"
tokio = { version = "1.18", features = ["full"] }
[dev-dependencies]
rand = "0.8.5"
rand = "0.8"

View File

@@ -9,6 +9,7 @@ version="0.10"
features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "ahash", "compute", "serde_types"]
[dependencies]
arc-swap = "1.0"
async-trait = "0.1"
common-error = { path = "../common/error" }
common-function = { path = "../common/function" }
@@ -27,5 +28,8 @@ tokio = "1.0"
sql = { path = "../sql" }
[dev-dependencies]
num = "0.4"
num-traits = "0.2"
rand = "0.8"
tokio = { version = "1.0", features = ["full"] }
tokio-stream = "0.1"

View File

@@ -7,6 +7,7 @@ mod planner;
use std::sync::Arc;
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_function::scalars::udf::create_udf;
use common_function::scalars::FunctionRef;
use common_query::prelude::ScalarUdf;
@@ -87,6 +88,17 @@ impl QueryEngine for DatafusionQueryEngine {
self.state.register_udf(udf);
}
/// Note in SQL queries, aggregate names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
/// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
///
/// So it's better to make UDAF name lowercase when creating one.
fn register_aggregate_function(&self, func: AggregateFunctionMetaRef) {
self.state.register_aggregate_function(func);
}
fn register_function(&self, func: FunctionRef) {
self.state.register_udf(create_udf(func));
}

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use arrow::datatypes::DataType;
use common_query::logical_plan::create_aggregate_function;
use datafusion::catalog::TableReference;
use datafusion::datasource::TableProvider;
use datafusion::physical_plan::udaf::AggregateUDF;
@@ -80,10 +81,8 @@ impl ContextProvider for DfContextProviderAdapter {
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state
.df_context()
.state
.lock()
.get_aggregate_meta(name)
.aggregate_function(name)
.map(|func| Arc::new(create_aggregate_function(func.name(), func.create()).into()))
}
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {

View File

@@ -3,6 +3,7 @@ mod state;
use std::sync::Arc;
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY};
use common_query::prelude::ScalarUdf;
use common_recordbatch::SendableRecordBatchStream;
@@ -35,6 +36,8 @@ pub trait QueryEngine: Send + Sync {
fn register_udf(&self, udf: ScalarUdf);
fn register_aggregate_function(&self, func: AggregateFunctionMetaRef);
fn register_function(&self, func: FunctionRef);
}
@@ -50,6 +53,10 @@ impl QueryEngineFactory {
query_engine.register_function(func);
}
for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
query_engine.register_aggregate_function(accumulator);
}
Self { query_engine }
}
}

View File

@@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_query::prelude::ScalarUdf;
use datafusion::prelude::{ExecutionConfig, ExecutionContext};
@@ -16,6 +18,7 @@ use crate::executor::Runtime;
pub struct QueryEngineState {
df_context: ExecutionContext,
catalog_list: CatalogListRef,
aggregate_functions: Arc<RwLock<HashMap<String, AggregateFunctionMetaRef>>>,
}
impl fmt::Debug for QueryEngineState {
@@ -41,6 +44,7 @@ impl QueryEngineState {
Self {
df_context,
catalog_list,
aggregate_functions: Arc::new(RwLock::new(HashMap::new())),
}
}
@@ -54,6 +58,23 @@ impl QueryEngineState {
.insert(udf.name.clone(), Arc::new(udf.into_df_udf()));
}
pub fn aggregate_function(&self, function_name: &str) -> Option<AggregateFunctionMetaRef> {
self.aggregate_functions
.read()
.unwrap()
.get(function_name)
.cloned()
}
pub fn register_aggregate_function(&self, func: AggregateFunctionMetaRef) {
// TODO(LFC): Return some error if there exists an aggregate function with the same name.
// Simply overwrite the old value for now.
self.aggregate_functions
.write()
.unwrap()
.insert(func.name(), func);
}
#[inline]
pub fn catalog_list(&self) -> &CatalogListRef {
&self.catalog_list

View File

@@ -0,0 +1,267 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
mod testing_table;
use arc_swap::ArcSwapOption;
use common_function::scalars::aggregate::AggregateFunctionMeta;
use common_query::error::CreateAccumulatorSnafu;
use common_query::error::Result as QueryResult;
use common_query::logical_plan::Accumulator;
use common_query::logical_plan::AggregateFunctionCreator;
use common_query::prelude::*;
use common_recordbatch::util;
use datafusion::arrow_print;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::prelude::*;
use datatypes::types::DataTypeBuilder;
use datatypes::types::PrimitiveType;
use datatypes::vectors::PrimitiveVector;
use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive;
use query::catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider};
use query::catalog::schema::SchemaProvider;
use query::catalog::{CatalogList, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use query::error::Result;
use query::query_engine::Output;
use query::QueryEngineFactory;
use table::TableRef;
use crate::testing_table::TestingTable;
#[derive(Debug, Default)]
struct MySumAccumulator<T, SumT>
where
T: Primitive + AsPrimitive<SumT>,
SumT: Primitive + std::ops::AddAssign,
{
sum: SumT,
_phantom: PhantomData<T>,
}
impl<T, SumT> MySumAccumulator<T, SumT>
where
T: Primitive + AsPrimitive<SumT>,
SumT: Primitive + std::ops::AddAssign,
{
#[inline(always)]
fn add(&mut self, v: T) {
self.sum += v.as_();
}
#[inline(always)]
fn merge(&mut self, s: SumT) {
self.sum += s;
}
}
#[derive(Debug, Default)]
struct MySumAccumulatorCreator {
input_type: ArcSwapOption<Vec<ConcreteDataType>>,
}
impl AggregateFunctionCreator for MySumAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(MySumAccumulator::<$S, <$S as Primitive>::LargestType>::default()))
},
{
let err_msg = format!(
"\"MY_SUM\" aggregate function not support data type {:?}",
input_type.logical_type_id(),
);
CreateAccumulatorSnafu { err_msg }.fail()?
}
)
});
creator
}
fn input_types(&self) -> QueryResult<Vec<ConcreteDataType>> {
Ok(self.input_type
.load()
.as_ref()
.expect("input_type is not present, check if DataFusion has changed its UDAF execution logic")
.as_ref()
.clone())
}
fn set_input_types(&self, input_types: Vec<ConcreteDataType>) -> QueryResult<()> {
let old = self.input_type.swap(Some(Arc::new(input_types.clone())));
if let Some(old) = old {
assert_eq!(old.len(), input_types.len());
old.iter().zip(input_types.iter()).for_each(|(x, y)|
assert_eq!(x, y, "input type {:?} != {:?}, check if DataFusion has changed its UDAF execution logic", x, y)
);
}
Ok(())
}
fn output_type(&self) -> QueryResult<ConcreteDataType> {
let input_type = &self.input_types()?[0];
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().logical_type_id().data_type())
},
{
unreachable!()
}
)
}
fn state_types(&self) -> QueryResult<Vec<ConcreteDataType>> {
Ok(vec![self.output_type()?])
}
}
impl<T, SumT> Accumulator for MySumAccumulator<T, SumT>
where
T: Primitive + AsPrimitive<SumT>,
for<'a> T: Scalar<RefType<'a> = T>,
SumT: Primitive + std::ops::AddAssign,
for<'a> SumT: Scalar<RefType<'a> = SumT>,
{
fn state(&self) -> QueryResult<Vec<Value>> {
Ok(vec![self.sum.into()])
}
fn update_batch(&mut self, values: &[VectorRef]) -> QueryResult<()> {
if values.is_empty() {
return Ok(());
};
let column = &values[0];
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(column) };
for v in column.iter_data().flatten() {
self.add(v)
}
Ok(())
}
fn merge_batch(&mut self, states: &[VectorRef]) -> QueryResult<()> {
if states.is_empty() {
return Ok(());
};
let states = &states[0];
let states: &<SumT as Scalar>::VectorType = unsafe { VectorHelper::static_cast(states) };
for s in states.iter_data().flatten() {
self.merge(s)
}
Ok(())
}
fn evaluate(&self) -> QueryResult<Value> {
Ok(self.sum.into())
}
}
#[tokio::test]
async fn test_my_sum() -> Result<()> {
common_telemetry::init_default_ut_logging();
test_my_sum_with(
(1..=10).collect::<Vec<u32>>(),
vec![
"+--------+",
"| my_sum |",
"+--------+",
"| 55 |",
"+--------+",
],
)
.await?;
test_my_sum_with(
(-10..=11).collect::<Vec<i32>>(),
vec![
"+--------+",
"| my_sum |",
"+--------+",
"| 11 |",
"+--------+",
],
)
.await?;
test_my_sum_with(
vec![-1.0f32, 1.0, 2.0, 3.0, 4.0],
vec![
"+--------+",
"| my_sum |",
"+--------+",
"| 9 |",
"+--------+",
],
)
.await?;
test_my_sum_with(
vec![u32::MAX, u32::MAX],
vec![
"+------------+",
"| my_sum |",
"+------------+",
"| 8589934590 |",
"+------------+",
],
)
.await?;
Ok(())
}
async fn test_my_sum_with<T>(numbers: Vec<T>, expected: Vec<&str>) -> Result<()>
where
T: Primitive + DataTypeBuilder,
{
let table_name = format!("{}_numbers", std::any::type_name::<T>());
let column_name = format!("{}_number", std::any::type_name::<T>());
let testing_table = Arc::new(TestingTable::new(
&column_name,
Arc::new(PrimitiveVector::<T>::from_vec(numbers.clone())),
));
let factory = new_query_engine_factory(table_name.clone(), testing_table);
let engine = factory.query_engine();
engine.register_aggregate_function(Arc::new(AggregateFunctionMeta::new(
"my_sum",
Arc::new(|| Arc::new(MySumAccumulatorCreator::default())),
)));
let sql = format!(
"select MY_SUM({}) as my_sum from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql)?;
let output = engine.execute(&plan).await?;
let recordbatch_stream = match output {
Output::RecordBatch(batch) => batch,
_ => unreachable!(),
};
let recordbatch = util::collect(recordbatch_stream).await.unwrap();
let df_recordbatch = recordbatch
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let pretty_print = arrow_print::write(&df_recordbatch);
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
assert_eq!(expected, pretty_print);
Ok(())
}
pub fn new_query_engine_factory(table_name: String, table: TableRef) -> QueryEngineFactory {
let schema_provider = Arc::new(MemorySchemaProvider::new());
let catalog_provider = Arc::new(MemoryCatalogProvider::new());
let catalog_list = Arc::new(MemoryCatalogList::default());
schema_provider.register_table(table_name, table).unwrap();
catalog_provider.register_schema(DEFAULT_SCHEMA_NAME, schema_provider);
catalog_list.register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider);
QueryEngineFactory::new(catalog_list)
}

View File

@@ -1,22 +1,33 @@
mod pow;
mod testing_table;
use std::sync::Arc;
use arrow::array::UInt32Array;
use common_query::prelude::{create_udf, make_scalar_function, Volatility};
use common_recordbatch::util;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::FieldExt;
use datafusion::field_util::SchemaExt;
use datafusion::logical_plan::LogicalPlanBuilder;
use datatypes::data_type::ConcreteDataType;
use query::catalog::memory;
use datatypes::for_all_ordered_primitive_types;
use datatypes::prelude::*;
use datatypes::types::DataTypeBuilder;
use datatypes::vectors::PrimitiveVector;
use num::NumCast;
use query::catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider};
use query::catalog::schema::SchemaProvider;
use query::catalog::{memory, CatalogList, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use query::error::Result;
use query::plan::LogicalPlan;
use query::query_engine::{Output, QueryEngineFactory};
use query::QueryEngine;
use rand::Rng;
use table::table::adapter::DfTableProviderAdapter;
use table::table::numbers::NumbersTable;
use crate::pow::pow;
use crate::testing_table::TestingTable;
#[tokio::test]
async fn test_datafusion_query_engine() -> Result<()> {
@@ -110,3 +121,170 @@ async fn test_udf() -> Result<()> {
Ok(())
}
fn create_query_engine() -> Arc<dyn QueryEngine> {
let schema_provider = Arc::new(MemorySchemaProvider::new());
let catalog_provider = Arc::new(MemoryCatalogProvider::new());
let catalog_list = Arc::new(MemoryCatalogList::default());
macro_rules! create_testing_table {
([], $( { $T:ty } ),*) => {
$(
let mut rng = rand::thread_rng();
let table_name = format!("{}_number_even", std::any::type_name::<$T>());
let column_name = table_name.clone();
let numbers = (1..=100).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
let table = Arc::new(TestingTable::new(
&column_name,
Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())),
));
schema_provider.register_table(table_name, table).unwrap();
let table_name = format!("{}_number_odd", std::any::type_name::<$T>());
let column_name = table_name.clone();
let numbers = (1..=99).map(|_| rng.gen_range(<$T>::MIN..<$T>::MAX)).collect::<Vec<$T>>();
let table = Arc::new(TestingTable::new(
&column_name,
Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())),
));
schema_provider.register_table(table_name, table).unwrap();
)*
}
}
for_all_ordered_primitive_types! { create_testing_table }
let table = Arc::new(TestingTable::new(
"f32_number",
Arc::new(PrimitiveVector::<f32>::from_vec(vec![1.0f32, 2.0, 3.0])),
));
schema_provider
.register_table("f32_number".to_string(), table)
.unwrap();
let table = Arc::new(TestingTable::new(
"f64_number",
Arc::new(PrimitiveVector::<f64>::from_vec(vec![1.0f64, 2.0, 3.0])),
));
schema_provider
.register_table("f64_number".to_string(), table)
.unwrap();
catalog_provider.register_schema(DEFAULT_SCHEMA_NAME, schema_provider);
catalog_list.register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider);
let factory = QueryEngineFactory::new(catalog_list);
factory.query_engine().clone()
}
async fn get_numbers_from_table<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Vec<T>
where
T: Primitive + DataTypeBuilder,
for<'a> T: Scalar<RefType<'a> = T>,
{
let column_name = table_name;
let sql = format!("SELECT {} FROM {}", column_name, table_name);
let plan = engine.sql_to_plan(&sql).unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {
Output::RecordBatch(batch) => batch,
_ => unreachable!(),
};
let numbers = util::collect(recordbatch_stream).await.unwrap();
let columns = numbers[0].df_recordbatch.columns();
let column = VectorHelper::try_into_vector(&columns[0]).unwrap();
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&column) };
column.iter_data().flatten().collect::<Vec<T>>()
}
#[tokio::test]
async fn test_median_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
test_median_failed::<f32>("f32_number", engine.clone()).await?;
test_median_failed::<f64>("f64_number", engine.clone()).await?;
macro_rules! test_median {
([], $( { $T:ty } ),*) => {
$(
let table_name = format!("{}_number_even", std::any::type_name::<$T>());
test_median_success::<$T>(&table_name, engine.clone()).await?;
let table_name = format!("{}_number_odd", std::any::type_name::<$T>());
test_median_success::<$T>(&table_name, engine.clone()).await?;
)*
}
}
for_all_ordered_primitive_types! { test_median }
Ok(())
}
async fn test_median_success<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Result<()>
where
T: Primitive + Ord + DataTypeBuilder,
for<'a> T: Scalar<RefType<'a> = T>,
{
let result = execute_median(table_name, engine.clone()).await.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("median", result[0].schema.arrow_schema().field(0).name());
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let median = v.get(0);
let mut numbers = get_numbers_from_table::<T>(table_name, engine.clone()).await;
numbers.sort();
let len = numbers.len();
let expected_median: Value = if len % 2 == 1 {
numbers[len / 2]
} else {
let a: f64 = NumCast::from(numbers[len / 2 - 1]).unwrap();
let b: f64 = NumCast::from(numbers[len / 2]).unwrap();
NumCast::from(a / 2.0 + b / 2.0).unwrap()
}
.into();
assert_eq!(expected_median, median);
Ok(())
}
async fn test_median_failed<T>(table_name: &str, engine: Arc<dyn QueryEngine>) -> Result<()>
where
T: Primitive + DataTypeBuilder,
{
let result = execute_median(table_name, engine).await;
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.to_string().contains(&format!(
"Failed to create accumulator: \"MEDIAN\" aggregate function not support data type {}",
T::type_name()
)));
Ok(())
}
async fn execute_median(
table_name: &str,
engine: Arc<dyn QueryEngine>,
) -> RecordResult<Vec<RecordBatch>> {
let column_name = table_name;
let sql = format!(
"select MEDIAN({}) as median from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {
Output::RecordBatch(batch) => batch,
_ => unreachable!(),
};
util::collect(recordbatch_stream).await
}

View File

@@ -0,0 +1,73 @@
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use common_query::prelude::Expr;
use common_recordbatch::error::Result as RecordBatchResult;
use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
use datatypes::prelude::VectorRef;
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use futures::task::{Context, Poll};
use futures::Stream;
use table::error::Result;
use table::Table;
#[derive(Debug, Clone)]
pub struct TestingTable {
records: RecordBatch,
}
impl TestingTable {
pub fn new(column_name: &str, values: VectorRef) -> Self {
let column_schemas = vec![ColumnSchema::new(column_name, values.data_type(), false)];
let schema = Arc::new(Schema::new(column_schemas));
Self {
records: RecordBatch::new(schema, vec![values]).unwrap(),
}
}
}
#[async_trait::async_trait]
impl Table for TestingTable {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.records.schema.clone()
}
async fn scan(
&self,
_projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(TestingRecordsStream {
schema: self.records.schema.clone(),
records: Some(self.records.clone()),
}))
}
}
impl RecordBatchStream for TestingRecordsStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
struct TestingRecordsStream {
schema: SchemaRef,
records: Option<RecordBatch>,
}
impl Stream for TestingRecordsStream {
type Item = RecordBatchResult<RecordBatch>;
fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.records.take() {
Some(records) => Poll::Ready(Some(Ok(records))),
None => Poll::Ready(None),
}
}
}