mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-17 21:40:37 +00:00
fix: Fix compiler errors in query crate (#746)
* fix: Fix compiler errors in state.rs * fix: fix compiler errors in state * feat: upgrade sqlparser to 0.26 * fix: fix datafusion engine compiler errors * fix: Fix some tests in query crate * fix: Fix all warnings in tests * feat: Remove `Type` from timestamp's type name * fix: fix query tests Now datafusion already supports median, so this commit also removes the median function * style: Fix clippy * feat: Remove RecordBatch::pretty_print * chore: Address CR comments * Update src/query/src/query_engine/state.rs Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
22
Cargo.lock
generated
22
Cargo.lock
generated
@@ -1899,7 +1899,7 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"rand 0.8.5",
|
||||
"smallvec",
|
||||
"sqlparser 0.26.0",
|
||||
"sqlparser",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@@ -1919,7 +1919,7 @@ dependencies = [
|
||||
"object_store",
|
||||
"ordered-float 3.4.0",
|
||||
"parquet",
|
||||
"sqlparser 0.26.0",
|
||||
"sqlparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1932,7 +1932,7 @@ dependencies = [
|
||||
"arrow",
|
||||
"datafusion-common",
|
||||
"log",
|
||||
"sqlparser 0.26.0",
|
||||
"sqlparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2003,7 +2003,7 @@ dependencies = [
|
||||
"arrow",
|
||||
"datafusion-common",
|
||||
"datafusion-expr",
|
||||
"sqlparser 0.26.0",
|
||||
"sqlparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2467,7 +2467,6 @@ dependencies = [
|
||||
"session",
|
||||
"snafu",
|
||||
"sql",
|
||||
"sqlparser 0.15.0",
|
||||
"store-api",
|
||||
"substrait 0.1.0",
|
||||
"table",
|
||||
@@ -4950,7 +4949,9 @@ dependencies = [
|
||||
"datafusion",
|
||||
"datafusion-common",
|
||||
"datafusion-expr",
|
||||
"datafusion-optimizer",
|
||||
"datafusion-physical-expr",
|
||||
"datafusion-sql",
|
||||
"datatypes",
|
||||
"format_num",
|
||||
"futures",
|
||||
@@ -6357,7 +6358,7 @@ dependencies = [
|
||||
"mito",
|
||||
"once_cell",
|
||||
"snafu",
|
||||
"sqlparser 0.15.0",
|
||||
"sqlparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6386,15 +6387,6 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlparser"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adbbea2526ad0d02ad9414a07c396078a5b944bbf9ca4fbab8f01bb4cb579081"
|
||||
dependencies = [
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlparser"
|
||||
version = "0.26.0"
|
||||
|
||||
@@ -23,6 +23,5 @@ pub(crate) mod test;
|
||||
mod timestamp;
|
||||
pub mod udf;
|
||||
|
||||
pub use aggregate::MedianAccumulatorCreator;
|
||||
pub use function::{Function, FunctionRef};
|
||||
pub use function_registry::{FunctionRegistry, FUNCTION_REGISTRY};
|
||||
|
||||
@@ -16,7 +16,6 @@ mod argmax;
|
||||
mod argmin;
|
||||
mod diff;
|
||||
mod mean;
|
||||
mod median;
|
||||
mod percentile;
|
||||
mod polyval;
|
||||
mod scipy_stats_norm_cdf;
|
||||
@@ -29,7 +28,6 @@ pub use argmin::ArgminAccumulatorCreator;
|
||||
use common_query::logical_plan::AggregateFunctionCreatorRef;
|
||||
pub use diff::DiffAccumulatorCreator;
|
||||
pub use mean::MeanAccumulatorCreator;
|
||||
pub use median::MedianAccumulatorCreator;
|
||||
pub use percentile::PercentileAccumulatorCreator;
|
||||
pub use polyval::PolyvalAccumulatorCreator;
|
||||
pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator;
|
||||
@@ -88,7 +86,6 @@ impl AggregateFunctions {
|
||||
};
|
||||
}
|
||||
|
||||
register_aggr_func!("median", 1, MedianAccumulatorCreator);
|
||||
register_aggr_func!("diff", 1, DiffAccumulatorCreator);
|
||||
register_aggr_func!("mean", 1, MeanAccumulatorCreator);
|
||||
register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator);
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BinaryHeap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
|
||||
use common_query::error::{
|
||||
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
|
||||
};
|
||||
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::types::OrdPrimitive;
|
||||
use datatypes::value::ListValue;
|
||||
use datatypes::vectors::{ConstantVector, Helper, ListVector};
|
||||
use datatypes::with_match_primitive_type_id;
|
||||
use num::NumCast;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
// This median calculation algorithm's details can be found at
|
||||
// https://leetcode.cn/problems/find-median-from-data-stream/
|
||||
//
|
||||
// Basically, it uses two heaps, a maximum heap and a minimum heap. The maximum heap stores numbers that
|
||||
// are not greater than the median, and the minimum heap stores the greater ones. In a stream of
|
||||
// numbers, when a number arrives, we adjust the heaps' tops, so that either one top is the
|
||||
// median or both tops can be averaged to get the median.
|
||||
//
|
||||
// The time complexity to update the median is O(logn), O(1) to get the median; and the space
|
||||
// complexity is O(n). (Ignore the costs for heap expansion.)
|
||||
//
|
||||
// From the point of algorithm, [quick select](https://en.wikipedia.org/wiki/Quickselect) might be
|
||||
// better. But to use quick select here, we need a mutable self in the final calculation(`evaluate`)
|
||||
// to swap stored numbers in the states vector. Though we could make our `evaluate` receive
|
||||
// `&mut self`, DataFusion calls our accumulator with `&self` (see `DfAccumulatorAdaptor`). That
|
||||
// means we have to introduce some kind of interior mutability, and the overhead is not negligible.
|
||||
//
|
||||
// TODO(LFC): Use quick select to get median when we can modify DataFusion's code, and benchmark with two-heap algorithm.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Median<T>
|
||||
where
|
||||
T: WrapperType,
|
||||
{
|
||||
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
|
||||
not_greater: BinaryHeap<OrdPrimitive<T>>,
|
||||
}
|
||||
|
||||
impl<T> Median<T>
|
||||
where
|
||||
T: WrapperType,
|
||||
{
|
||||
fn push(&mut self, value: T) {
|
||||
let value = OrdPrimitive::<T>(value);
|
||||
|
||||
if self.not_greater.is_empty() {
|
||||
self.not_greater.push(value);
|
||||
return;
|
||||
}
|
||||
// The `unwrap`s below are safe because there are `push`s before them.
|
||||
if value <= *self.not_greater.peek().unwrap() {
|
||||
self.not_greater.push(value);
|
||||
if self.not_greater.len() > self.greater.len() + 1 {
|
||||
self.greater.push(Reverse(self.not_greater.pop().unwrap()));
|
||||
}
|
||||
} else {
|
||||
self.greater.push(Reverse(value));
|
||||
if self.greater.len() > self.not_greater.len() {
|
||||
self.not_greater.push(self.greater.pop().unwrap().0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
|
||||
// to use them.
|
||||
impl<T> Accumulator for Median<T>
|
||||
where
|
||||
T: WrapperType + NumCast,
|
||||
{
|
||||
// This function serializes our state to `ScalarValue`, which DataFusion uses to pass this
|
||||
// state between execution stages. Note that this can be arbitrary data.
|
||||
//
|
||||
// The `ScalarValue`s returned here will be passed in as argument `states: &[VectorRef]` to
|
||||
// `merge_batch` function.
|
||||
fn state(&self) -> Result<Vec<Value>> {
|
||||
let nums = self
|
||||
.greater
|
||||
.iter()
|
||||
.map(|x| &x.0)
|
||||
.chain(self.not_greater.iter())
|
||||
.map(|&n| n.into())
|
||||
.collect::<Vec<Value>>();
|
||||
Ok(vec![Value::List(ListValue::new(
|
||||
Some(Box::new(nums)),
|
||||
T::LogicalType::build_data_type(),
|
||||
))])
|
||||
}
|
||||
|
||||
// DataFusion calls this function to update the accumulator's state for a batch of inputs rows.
|
||||
// This function is expected to update the accumulator's state.
|
||||
fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> {
|
||||
if values.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
ensure!(values.len() == 1, InvalidInputStateSnafu);
|
||||
|
||||
// This is a unary accumulator, so only one column is provided.
|
||||
let column = &values[0];
|
||||
let mut len = 1;
|
||||
let column: &<T as Scalar>::VectorType = if column.is_const() {
|
||||
len = column.len();
|
||||
let column: &ConstantVector = unsafe { Helper::static_cast(column) };
|
||||
unsafe { Helper::static_cast(column.inner()) }
|
||||
} else {
|
||||
unsafe { Helper::static_cast(column) }
|
||||
};
|
||||
(0..len).for_each(|_| {
|
||||
for v in column.iter_data().flatten() {
|
||||
self.push(v);
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// DataFusion executes accumulators in partitions. In some execution stage, DataFusion will
|
||||
// merge states from other accumulators (returned by `state()` method).
|
||||
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
|
||||
if states.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// The states here are returned by the `state` method. Since we only returned a vector
|
||||
// with one value in that method, `states[0]` is fine.
|
||||
let states = &states[0];
|
||||
let states = states
|
||||
.as_any()
|
||||
.downcast_ref::<ListVector>()
|
||||
.with_context(|| DowncastVectorSnafu {
|
||||
err_msg: format!(
|
||||
"expect ListVector, got vector type {}",
|
||||
states.vector_type_name()
|
||||
),
|
||||
})?;
|
||||
for state in states.values_iter() {
|
||||
if let Some(state) = state.context(FromScalarValueSnafu)? {
|
||||
// merging state simply accumulates the numbers stored by the other accumulators, so just call update
|
||||
self.update_batch(&[state])?
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// DataFusion expects this function to return the final value of this aggregator.
|
||||
fn evaluate(&self) -> Result<Value> {
|
||||
if self.not_greater.is_empty() {
|
||||
assert!(
|
||||
self.greater.is_empty(),
|
||||
"not expected in two-heap median algorithm, there must be a bug when implementing it"
|
||||
);
|
||||
return Ok(Value::Null);
|
||||
}
|
||||
|
||||
// unwrap is safe because we checked not_greater heap's len above
|
||||
let not_greater = *self.not_greater.peek().unwrap();
|
||||
let median = if self.not_greater.len() > self.greater.len() {
|
||||
not_greater.into()
|
||||
} else {
|
||||
// unwrap is safe because greater heap len >= not_greater heap len, which is > 0
|
||||
let greater = self.greater.peek().unwrap();
|
||||
|
||||
// the following three NumCast's `unwrap`s are safe because T is primitive
|
||||
let not_greater_v: f64 = NumCast::from(not_greater.as_primitive()).unwrap();
|
||||
let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
|
||||
let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap();
|
||||
median.into()
|
||||
};
|
||||
Ok(median)
|
||||
}
|
||||
}
|
||||
|
||||
#[as_aggr_func_creator]
|
||||
#[derive(Debug, Default, AggrFuncTypeStore)]
|
||||
pub struct MedianAccumulatorCreator {}
|
||||
|
||||
impl AggregateFunctionCreator for MedianAccumulatorCreator {
|
||||
fn creator(&self) -> AccumulatorCreatorFunction {
|
||||
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
|
||||
let input_type = &types[0];
|
||||
with_match_primitive_type_id!(
|
||||
input_type.logical_type_id(),
|
||||
|$S| {
|
||||
Ok(Box::new(Median::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
|
||||
},
|
||||
{
|
||||
let err_msg = format!(
|
||||
"\"MEDIAN\" aggregate function not support data type {:?}",
|
||||
input_type.logical_type_id(),
|
||||
);
|
||||
CreateAccumulatorSnafu { err_msg }.fail()?
|
||||
}
|
||||
)
|
||||
});
|
||||
creator
|
||||
}
|
||||
|
||||
fn output_type(&self) -> Result<ConcreteDataType> {
|
||||
let input_types = self.input_types()?;
|
||||
ensure!(input_types.len() == 1, InvalidInputStateSnafu);
|
||||
// unwrap is safe because we have checked that input_types len must equal 1
|
||||
Ok(input_types.into_iter().next().unwrap())
|
||||
}
|
||||
|
||||
fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
|
||||
Ok(vec![ConcreteDataType::list_datatype(self.output_type()?)])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use datatypes::vectors::Int32Vector;
|
||||
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_update_batch() {
|
||||
// test update empty batch, expect not updating anything
|
||||
let mut median = Median::<i32>::default();
|
||||
assert!(median.update_batch(&[]).is_ok());
|
||||
assert!(median.not_greater.is_empty());
|
||||
assert!(median.greater.is_empty());
|
||||
assert_eq!(Value::Null, median.evaluate().unwrap());
|
||||
|
||||
// test update one not-null value
|
||||
let mut median = Median::<i32>::default();
|
||||
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
|
||||
assert!(median.update_batch(&v).is_ok());
|
||||
assert_eq!(Value::Int32(42), median.evaluate().unwrap());
|
||||
|
||||
// test update one null value
|
||||
let mut median = Median::<i32>::default();
|
||||
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
|
||||
assert!(median.update_batch(&v).is_ok());
|
||||
assert_eq!(Value::Null, median.evaluate().unwrap());
|
||||
|
||||
// test update no null-value batch
|
||||
let mut median = Median::<i32>::default();
|
||||
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
|
||||
Some(-1i32),
|
||||
Some(1),
|
||||
Some(2),
|
||||
]))];
|
||||
assert!(median.update_batch(&v).is_ok());
|
||||
assert_eq!(Value::Int32(1), median.evaluate().unwrap());
|
||||
|
||||
// test update null-value batch
|
||||
let mut median = Median::<i32>::default();
|
||||
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
|
||||
Some(-2i32),
|
||||
None,
|
||||
Some(3),
|
||||
Some(4),
|
||||
]))];
|
||||
assert!(median.update_batch(&v).is_ok());
|
||||
assert_eq!(Value::Int32(3), median.evaluate().unwrap());
|
||||
|
||||
// test update with constant vector
|
||||
let mut median = Median::<i32>::default();
|
||||
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
|
||||
Arc::new(Int32Vector::from_vec(vec![4])),
|
||||
10,
|
||||
))];
|
||||
assert!(median.update_batch(&v).is_ok());
|
||||
assert_eq!(Value::Int32(4), median.evaluate().unwrap());
|
||||
}
|
||||
}
|
||||
@@ -66,7 +66,7 @@ impl Stream for EmptyRecordBatchStream {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct RecordBatches {
|
||||
schema: SchemaRef,
|
||||
batches: Vec<RecordBatch>,
|
||||
|
||||
@@ -15,13 +15,20 @@
|
||||
use futures::TryStreamExt;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::{RecordBatch, SendableRecordBatchStream};
|
||||
use crate::{RecordBatch, RecordBatches, SendableRecordBatchStream};
|
||||
|
||||
/// Collect all the items from the stream into a vector of [`RecordBatch`].
|
||||
pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
|
||||
stream.try_collect::<Vec<_>>().await
|
||||
}
|
||||
|
||||
/// Collect all the items from the stream into [RecordBatches].
|
||||
pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result<RecordBatches> {
|
||||
let schema = stream.schema();
|
||||
let batches = stream.try_collect::<Vec<_>>().await?;
|
||||
RecordBatches::try_new(schema, batches)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::mem;
|
||||
@@ -90,7 +97,14 @@ mod tests {
|
||||
};
|
||||
let batches = collect(Box::pin(stream)).await.unwrap();
|
||||
assert_eq!(1, batches.len());
|
||||
|
||||
assert_eq!(batch, batches[0]);
|
||||
|
||||
let stream = MockRecordBatchStream {
|
||||
schema: schema.clone(),
|
||||
batch: Some(batch.clone()),
|
||||
};
|
||||
let batches = collect_batches(Box::pin(stream)).await.unwrap();
|
||||
let expect_batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap();
|
||||
assert_eq!(expect_batches, batches);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ macro_rules! impl_data_type_for_timestamp {
|
||||
|
||||
impl DataType for [<Timestamp $unit Type>] {
|
||||
fn name(&self) -> &str {
|
||||
stringify!([<Timestamp $unit Type>])
|
||||
stringify!([<Timestamp $unit>])
|
||||
}
|
||||
|
||||
fn logical_type_id(&self) -> LogicalTypeId {
|
||||
|
||||
@@ -1346,7 +1346,7 @@ mod tests {
|
||||
ConcreteDataType::timestamp_second_datatype(),
|
||||
))
|
||||
.to_string(),
|
||||
"TimestampSecondType[]"
|
||||
"TimestampSecond[]"
|
||||
);
|
||||
assert_eq!(
|
||||
Value::List(ListValue::new(
|
||||
@@ -1354,7 +1354,7 @@ mod tests {
|
||||
ConcreteDataType::timestamp_millisecond_datatype(),
|
||||
))
|
||||
.to_string(),
|
||||
"TimestampMillisecondType[]"
|
||||
"TimestampMillisecond[]"
|
||||
);
|
||||
assert_eq!(
|
||||
Value::List(ListValue::new(
|
||||
@@ -1362,7 +1362,7 @@ mod tests {
|
||||
ConcreteDataType::timestamp_microsecond_datatype(),
|
||||
))
|
||||
.to_string(),
|
||||
"TimestampMicrosecondType[]"
|
||||
"TimestampMicrosecond[]"
|
||||
);
|
||||
assert_eq!(
|
||||
Value::List(ListValue::new(
|
||||
@@ -1370,7 +1370,7 @@ mod tests {
|
||||
ConcreteDataType::timestamp_nanosecond_datatype(),
|
||||
))
|
||||
.to_string(),
|
||||
"TimestampNanosecondType[]"
|
||||
"TimestampNanosecond[]"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,6 @@ servers = { path = "../servers" }
|
||||
session = { path = "../session" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
sql = { path = "../sql" }
|
||||
sqlparser = "0.15"
|
||||
store-api = { path = "../store-api" }
|
||||
substrait = { path = "../common/substrait" }
|
||||
table = { path = "../table" }
|
||||
|
||||
@@ -19,9 +19,9 @@ use api::helper::ColumnDataTypeWrapper;
|
||||
use api::v1::{Column, ColumnDataType, CreateExpr};
|
||||
use datatypes::schema::ColumnSchema;
|
||||
use snafu::{ensure, ResultExt};
|
||||
use sql::ast::{ColumnDef, TableConstraint};
|
||||
use sql::statements::create::{CreateTable, TIME_INDEX};
|
||||
use sql::statements::{column_def_to_schema, table_idents_to_full_name};
|
||||
use sqlparser::ast::{ColumnDef, TableConstraint};
|
||||
|
||||
use crate::error::{
|
||||
BuildCreateExprOnInsertionSnafu, ColumnDataTypeSnafu, ConvertColumnDefaultConstraintSnafu,
|
||||
|
||||
@@ -35,10 +35,10 @@ use query::sql::{describe_table, explain, show_databases, show_tables};
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
use sql::ast::Value as SqlValue;
|
||||
use sql::statements::create::Partitions;
|
||||
use sql::statements::sql_value_to_value;
|
||||
use sql::statements::statement::Statement;
|
||||
use sqlparser::ast::Value as SqlValue;
|
||||
use table::metadata::{RawTableInfo, RawTableMeta, TableIdent, TableType};
|
||||
|
||||
use crate::catalog::FrontendCatalogManager;
|
||||
@@ -454,9 +454,9 @@ fn find_partition_columns(
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use sql::dialect::GenericDialect;
|
||||
use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
use sqlparser::dialect::GenericDialect;
|
||||
|
||||
use super::*;
|
||||
use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory};
|
||||
|
||||
@@ -530,9 +530,9 @@ mod test {
|
||||
use meta_srv::mocks::MockInfo;
|
||||
use meta_srv::service::store::kv::KvStoreRef;
|
||||
use meta_srv::service::store::memory::MemStore;
|
||||
use sql::dialect::GenericDialect;
|
||||
use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
use sqlparser::dialect::GenericDialect;
|
||||
use table::metadata::{TableInfoBuilder, TableMetaBuilder};
|
||||
use table::TableRef;
|
||||
use tempdir::TempDir;
|
||||
|
||||
@@ -18,7 +18,9 @@ common-time = { path = "../common/time" }
|
||||
datafusion = "14.0.0"
|
||||
datafusion-common = "14.0.0"
|
||||
datafusion-expr = "14.0.0"
|
||||
datafusion-optimizer = "14.0.0"
|
||||
datafusion-physical-expr = "14.0.0"
|
||||
datafusion-sql = "14.0.0"
|
||||
datatypes = { path = "../datatypes" }
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
|
||||
@@ -141,7 +141,6 @@ impl LogicalOptimizer for DatafusionQueryEngine {
|
||||
LogicalPlan::DfPlan(df_plan) => {
|
||||
let optimized_plan =
|
||||
self.state
|
||||
.df_context()
|
||||
.optimize(df_plan)
|
||||
.context(error::DatafusionSnafu {
|
||||
msg: "Fail to optimize logical plan",
|
||||
@@ -163,14 +162,11 @@ impl PhysicalPlanner for DatafusionQueryEngine {
|
||||
let _timer = timer!(metric::METRIC_CREATE_PHYSICAL_ELAPSED);
|
||||
match logical_plan {
|
||||
LogicalPlan::DfPlan(df_plan) => {
|
||||
let physical_plan = self
|
||||
.state
|
||||
.df_context()
|
||||
.create_physical_plan(df_plan)
|
||||
.await
|
||||
.context(error::DatafusionSnafu {
|
||||
let physical_plan = self.state.create_physical_plan(df_plan).await.context(
|
||||
error::DatafusionSnafu {
|
||||
msg: "Fail to create physical plan",
|
||||
})?;
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(PhysicalPlanAdapter::new(
|
||||
Arc::new(
|
||||
@@ -193,22 +189,19 @@ impl PhysicalOptimizer for DatafusionQueryEngine {
|
||||
plan: Arc<dyn PhysicalPlan>,
|
||||
) -> Result<Arc<dyn PhysicalPlan>> {
|
||||
let _timer = timer!(metric::METRIC_OPTIMIZE_PHYSICAL_ELAPSED);
|
||||
let config = &self.state.df_context().state.lock().config;
|
||||
let optimizers = &config.physical_optimizers;
|
||||
|
||||
let mut new_plan = plan
|
||||
let new_plan = plan
|
||||
.as_any()
|
||||
.downcast_ref::<PhysicalPlanAdapter>()
|
||||
.context(error::PhysicalPlanDowncastSnafu)?
|
||||
.df_plan();
|
||||
|
||||
for optimizer in optimizers {
|
||||
new_plan = optimizer
|
||||
.optimize(new_plan, config)
|
||||
let new_plan =
|
||||
self.state
|
||||
.optimize_physical_plan(new_plan)
|
||||
.context(error::DatafusionSnafu {
|
||||
msg: "Fail to optimize physical plan",
|
||||
})?;
|
||||
}
|
||||
Ok(Arc::new(PhysicalPlanAdapter::new(plan.schema(), new_plan)))
|
||||
}
|
||||
}
|
||||
@@ -224,7 +217,7 @@ impl QueryExecutor for DatafusionQueryEngine {
|
||||
match plan.output_partitioning().partition_count() {
|
||||
0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))),
|
||||
1 => Ok(plan
|
||||
.execute(0, ctx.state().runtime())
|
||||
.execute(0, ctx.state().task_ctx())
|
||||
.context(error::ExecutePhysicalPlanSnafu)?),
|
||||
_ => {
|
||||
// merge into a single partition
|
||||
@@ -232,11 +225,11 @@ impl QueryExecutor for DatafusionQueryEngine {
|
||||
CoalescePartitionsExec::new(Arc::new(DfPhysicalPlanAdapter(plan.clone())));
|
||||
// CoalescePartitionsExec must produce a single partition
|
||||
assert_eq!(1, plan.output_partitioning().partition_count());
|
||||
let df_stream = plan.execute(0, ctx.state().runtime()).await.context(
|
||||
error::DatafusionSnafu {
|
||||
msg: "Failed to execute DataFusion merge exec",
|
||||
},
|
||||
)?;
|
||||
let df_stream =
|
||||
plan.execute(0, ctx.state().task_ctx())
|
||||
.context(error::DatafusionSnafu {
|
||||
msg: "Failed to execute DataFusion merge exec",
|
||||
})?;
|
||||
let stream = RecordBatchStreamAdapter::try_new(df_stream)
|
||||
.context(error::ConvertDfRecordBatchStreamSnafu)?;
|
||||
Ok(Box::pin(stream))
|
||||
@@ -254,8 +247,7 @@ mod tests {
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_query::Output;
|
||||
use common_recordbatch::util;
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::arrow::array::UInt64Array;
|
||||
use datatypes::vectors::{UInt64Vector, VectorRef};
|
||||
use session::context::QueryContext;
|
||||
use table::table::numbers::NumbersTable;
|
||||
|
||||
@@ -290,10 +282,10 @@ mod tests {
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", plan),
|
||||
r#"DfPlan(Limit: 20
|
||||
Projection: #SUM(numbers.number)
|
||||
Aggregate: groupBy=[[]], aggr=[[SUM(#numbers.number)]]
|
||||
TableScan: numbers projection=None)"#
|
||||
r#"DfPlan(Limit: skip=0, fetch=20
|
||||
Projection: SUM(numbers.number)
|
||||
Aggregate: groupBy=[[]], aggr=[[SUM(numbers.number)]]
|
||||
TableScan: numbers)"#
|
||||
);
|
||||
}
|
||||
|
||||
@@ -311,20 +303,20 @@ mod tests {
|
||||
Output::Stream(recordbatch) => {
|
||||
let numbers = util::collect(recordbatch).await.unwrap();
|
||||
assert_eq!(1, numbers.len());
|
||||
assert_eq!(numbers[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, numbers[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!(numbers[0].num_columns(), 1);
|
||||
assert_eq!(1, numbers[0].schema.num_columns());
|
||||
assert_eq!(
|
||||
"SUM(numbers.number)",
|
||||
numbers[0].schema.arrow_schema().field(0).name()
|
||||
numbers[0].schema.column_schemas()[0].name
|
||||
);
|
||||
|
||||
let columns = numbers[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let batch = &numbers[0];
|
||||
assert_eq!(1, batch.num_columns());
|
||||
assert_eq!(batch.column(0).len(), 1);
|
||||
|
||||
assert_eq!(
|
||||
*columns[0].as_any().downcast_ref::<UInt64Array>().unwrap(),
|
||||
UInt64Array::from_slice(&[4950])
|
||||
*batch.column(0),
|
||||
Arc::new(UInt64Vector::from_slice(&[4950])) as VectorRef
|
||||
);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
|
||||
@@ -12,14 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::logical_plan::create_aggregate_function;
|
||||
use datafusion::catalog::TableReference;
|
||||
use datafusion::datasource::TableProvider;
|
||||
use datafusion::error::Result as DfResult;
|
||||
use datafusion::physical_plan::udaf::AggregateUDF;
|
||||
use datafusion::physical_plan::udf::ScalarUDF;
|
||||
use datafusion::sql::planner::{ContextProvider, SqlToRel};
|
||||
use datafusion_expr::TableSource;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ResultExt;
|
||||
@@ -50,7 +52,7 @@ impl<'a, S: ContextProvider + Send + Sync> DfPlanner<'a, S> {
|
||||
let sql = query.inner.to_string();
|
||||
let result = self
|
||||
.sql_to_rel
|
||||
.query_to_plan(query.inner)
|
||||
.query_to_plan(query.inner, &mut HashMap::new())
|
||||
.context(error::PlanSqlSnafu { sql })?;
|
||||
|
||||
Ok(LogicalPlan::DfPlan(result))
|
||||
@@ -103,26 +105,14 @@ impl DfContextProviderAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO(dennis): Delegate all requests to ExecutionContext right now,
|
||||
/// manage UDFs, UDAFs, variables by ourself in future.
|
||||
impl ContextProvider for DfContextProviderAdapter {
|
||||
fn get_table_provider(&self, name: TableReference) -> Option<Arc<dyn TableProvider>> {
|
||||
fn get_table_provider(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
|
||||
let schema = self.query_ctx.current_schema();
|
||||
let execution_ctx = self.state.df_context().state.lock();
|
||||
match name {
|
||||
TableReference::Bare { table } if schema.is_some() => {
|
||||
execution_ctx.get_table_provider(TableReference::Partial {
|
||||
// unwrap safety: checked in this match's arm
|
||||
schema: &schema.unwrap(),
|
||||
table,
|
||||
})
|
||||
}
|
||||
_ => execution_ctx.get_table_provider(name),
|
||||
}
|
||||
self.state.get_table_provider(schema.as_deref(), name)
|
||||
}
|
||||
|
||||
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
|
||||
self.state.df_context().state.lock().get_function_meta(name)
|
||||
self.state.get_function_meta(name)
|
||||
}
|
||||
|
||||
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
|
||||
@@ -134,10 +124,6 @@ impl ContextProvider for DfContextProviderAdapter {
|
||||
}
|
||||
|
||||
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
|
||||
self.state
|
||||
.df_context()
|
||||
.state
|
||||
.lock()
|
||||
.get_variable_type(variable_names)
|
||||
self.state.get_variable_type(variable_names)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,12 +44,10 @@ impl OptimizerRule for TypeConversionRule {
|
||||
};
|
||||
|
||||
match plan {
|
||||
LogicalPlan::Filter(Filter { predicate, input }) => {
|
||||
Ok(LogicalPlan::Filter(Filter::try_new(
|
||||
predicate.clone().rewrite(&mut converter)?,
|
||||
Arc::new(self.optimize(input, optimizer_config)?),
|
||||
)?))
|
||||
}
|
||||
LogicalPlan::Filter(filter) => Ok(LogicalPlan::Filter(Filter::try_new(
|
||||
filter.predicate().clone().rewrite(&mut converter)?,
|
||||
Arc::new(self.optimize(filter.input(), optimizer_config)?),
|
||||
)?)),
|
||||
LogicalPlan::TableScan(TableScan {
|
||||
table_name,
|
||||
source,
|
||||
@@ -150,8 +148,8 @@ impl<'a> TypeConverter<'a> {
|
||||
compute::cast(&value_arr, target_type).map_err(DataFusionError::ArrowError)?;
|
||||
|
||||
ScalarValue::try_from_array(
|
||||
&Arc::from(arr), // index: Converts a value in `array` at `index` into a ScalarValue
|
||||
0,
|
||||
&arr,
|
||||
0, // index: Converts a value in `array` at `index` into a ScalarValue
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,14 +21,17 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
|
||||
use common_query::physical_plan::{SessionContext, TaskContext};
|
||||
use common_query::prelude::ScalarUdf;
|
||||
use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
|
||||
use datafusion::optimizer::eliminate_limit::EliminateLimit;
|
||||
use datafusion::optimizer::filter_push_down::FilterPushDown;
|
||||
use datafusion::optimizer::limit_push_down::LimitPushDown;
|
||||
use datafusion::optimizer::projection_push_down::ProjectionPushDown;
|
||||
use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
|
||||
use datafusion::optimizer::to_approx_perc::ToApproxPerc;
|
||||
use datafusion::execution::context::SessionConfig;
|
||||
use datafusion::catalog::TableReference;
|
||||
use datafusion::error::Result as DfResult;
|
||||
use datafusion::execution::context::{SessionConfig, SessionState};
|
||||
use datafusion::execution::runtime_env::RuntimeEnv;
|
||||
use datafusion::physical_plan::udf::ScalarUDF;
|
||||
use datafusion::physical_plan::ExecutionPlan;
|
||||
use datafusion_expr::{LogicalPlan as DfLogicalPlan, TableSource};
|
||||
use datafusion_optimizer::optimizer::{Optimizer, OptimizerConfig};
|
||||
use datafusion_sql::planner::ContextProvider;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
|
||||
use crate::datafusion::DfCatalogListAdapter;
|
||||
use crate::optimizer::TypeConversionRule;
|
||||
|
||||
@@ -38,7 +41,7 @@ use crate::optimizer::TypeConversionRule;
|
||||
// type in QueryEngine trait.
|
||||
#[derive(Clone)]
|
||||
pub struct QueryEngineState {
|
||||
df_context: ExecutionContext,
|
||||
df_context: SessionContext,
|
||||
catalog_list: CatalogListRef,
|
||||
aggregate_functions: Arc<RwLock<HashMap<String, AggregateFunctionMetaRef>>>,
|
||||
}
|
||||
@@ -52,25 +55,18 @@ impl fmt::Debug for QueryEngineState {
|
||||
|
||||
impl QueryEngineState {
|
||||
pub(crate) fn new(catalog_list: CatalogListRef) -> Self {
|
||||
let config = ExecutionConfig::new()
|
||||
.with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)
|
||||
.with_optimizer_rules(vec![
|
||||
// TODO(hl): SimplifyExpressions is not exported.
|
||||
Arc::new(TypeConversionRule {}),
|
||||
// These are the default optimizer in datafusion
|
||||
Arc::new(CommonSubexprEliminate::new()),
|
||||
Arc::new(EliminateLimit::new()),
|
||||
Arc::new(ProjectionPushDown::new()),
|
||||
Arc::new(FilterPushDown::new()),
|
||||
Arc::new(LimitPushDown::new()),
|
||||
Arc::new(SingleDistinctToGroupBy::new()),
|
||||
Arc::new(ToApproxPerc::new()),
|
||||
]);
|
||||
let runtime_env = Arc::new(RuntimeEnv::default());
|
||||
let session_config = SessionConfig::new()
|
||||
.with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME);
|
||||
let mut optimizer = Optimizer::new(&OptimizerConfig::new());
|
||||
// Apply the type conversion rule first.
|
||||
optimizer.rules.insert(0, Arc::new(TypeConversionRule {}));
|
||||
|
||||
let df_context = ExecutionContext::with_config(config);
|
||||
let mut session_state = SessionState::with_config_rt(session_config, runtime_env);
|
||||
session_state.optimizer = optimizer;
|
||||
session_state.catalog_list = Arc::new(DfCatalogListAdapter::new(catalog_list.clone()));
|
||||
|
||||
df_context.state.lock().catalog_list =
|
||||
Arc::new(DfCatalogListAdapter::new(catalog_list.clone()));
|
||||
let df_context = SessionContext::with_state(session_state);
|
||||
|
||||
Self {
|
||||
df_context,
|
||||
@@ -80,11 +76,15 @@ impl QueryEngineState {
|
||||
}
|
||||
|
||||
/// Register a udf function
|
||||
/// TODO(dennis): manage UDFs by ourself.
|
||||
// TODO(dennis): manage UDFs by ourself.
|
||||
pub fn register_udf(&self, udf: ScalarUdf) {
|
||||
// `SessionContext` has a `register_udf()` method, which requires `&mut self`, this is
|
||||
// a workaround.
|
||||
// TODO(yingwen): Use `SessionContext::register_udf()` once it takes `&self`.
|
||||
// It's implemented in https://github.com/apache/arrow-datafusion/pull/4612
|
||||
self.df_context
|
||||
.state
|
||||
.lock()
|
||||
.write()
|
||||
.scalar_functions
|
||||
.insert(udf.name.clone(), Arc::new(udf.into_df_udf()));
|
||||
}
|
||||
@@ -112,12 +112,59 @@ impl QueryEngineState {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn df_context(&self) -> &ExecutionContext {
|
||||
&self.df_context
|
||||
pub(crate) fn task_ctx(&self) -> Arc<TaskContext> {
|
||||
self.df_context.task_ctx()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn runtime(&self) -> Arc<RuntimeEnv> {
|
||||
self.df_context.runtime_env()
|
||||
/// Resolves `name` to a table source through the DataFusion session state.
///
/// When `name` is a bare table reference and `schema` is `Some`, the lookup
/// is performed with a `TableReference::Partial` built from that schema and
/// the bare table name. In every other case `name` is forwarded unchanged.
///
/// Returns whatever the session state's `get_table_provider` returns,
/// including its error.
pub(crate) fn get_table_provider(
|
||||
&self,
|
||||
schema: Option<&str>,
|
||||
name: TableReference,
|
||||
) -> DfResult<Arc<dyn TableSource>> {
|
||||
let state = self.df_context.state.read();
|
||||
match name {
|
||||
TableReference::Bare { table } if schema.is_some() => {
|
||||
state.get_table_provider(TableReference::Partial {
|
||||
// unwrap safety: checked in this match's arm
|
||||
schema: schema.unwrap(),
|
||||
table,
|
||||
})
|
||||
}
|
||||
_ => state.get_table_provider(name),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the scalar UDF that the DataFusion session state reports for
/// `name`, by delegating to the state's `get_function_meta`.
// NOTE(review): presumably `None` means no UDF is registered under `name` — confirm.
pub(crate) fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
|
||||
let state = self.df_context.state.read();
|
||||
state.get_function_meta(name)
|
||||
}
|
||||
|
||||
/// Returns the data type that the DataFusion session state reports for the
/// variable identified by `variable_names`, by delegating to the state's
/// `get_variable_type`.
pub(crate) fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
|
||||
let state = self.df_context.state.read();
|
||||
state.get_variable_type(variable_names)
|
||||
}
|
||||
|
||||
/// Optimizes the logical `plan` by delegating to `self.df_context.optimize`
/// and returns the resulting plan or the error it produced.
pub(crate) fn optimize(&self, plan: &DfLogicalPlan) -> DfResult<DfLogicalPlan> {
|
||||
self.df_context.optimize(plan)
|
||||
}
|
||||
|
||||
/// Asynchronously builds a physical execution plan for `logical_plan` by
/// delegating to `self.df_context.create_physical_plan`.
pub(crate) async fn create_physical_plan(
|
||||
&self,
|
||||
logical_plan: &DfLogicalPlan,
|
||||
) -> DfResult<Arc<dyn ExecutionPlan>> {
|
||||
self.df_context.create_physical_plan(logical_plan).await
|
||||
}
|
||||
|
||||
/// Applies every optimizer in the session state's `physical_optimizers` to
/// `plan`, in iteration order.
///
/// Each optimizer receives the plan produced by the previous one together
/// with the session state's `config`. The first optimizer error is returned
/// immediately via `?`; otherwise the final plan is returned.
pub(crate) fn optimize_physical_plan(
|
||||
&self,
|
||||
mut plan: Arc<dyn ExecutionPlan>,
|
||||
) -> DfResult<Arc<dyn ExecutionPlan>> {
|
||||
let state = self.df_context.state.read();
|
||||
let config = &state.config;
|
||||
for optimizer in &state.physical_optimizers {
|
||||
// Feed the output of the previous optimizer into the next one.
plan = optimizer.optimize(plan, config)?;
|
||||
}
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,10 +261,9 @@ mod test {
|
||||
use common_query::Output;
|
||||
use common_recordbatch::{RecordBatch, RecordBatches};
|
||||
use common_time::timestamp::TimeUnit;
|
||||
use datatypes::arrow::array::PrimitiveArray;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema, Schema, SchemaRef};
|
||||
use datatypes::vectors::{StringVector, TimestampVector, UInt32Vector, VectorRef};
|
||||
use datatypes::vectors::{StringVector, TimestampMillisecondVector, UInt32Vector, VectorRef};
|
||||
use snafu::ResultExt;
|
||||
use sql::statements::describe::DescribeTable;
|
||||
use table::test_util::MemTable;
|
||||
@@ -379,12 +378,12 @@ mod test {
|
||||
.with_time_index(true),
|
||||
];
|
||||
let data = vec![
|
||||
Arc::new(UInt32Vector::from_vec(vec![0])) as _,
|
||||
Arc::new(TimestampVector::new(PrimitiveArray::from_vec(vec![0]))) as _,
|
||||
Arc::new(UInt32Vector::from_slice(&[0])) as _,
|
||||
Arc::new(TimestampMillisecondVector::from_slice(&[0])) as _,
|
||||
];
|
||||
let expected_columns = vec![
|
||||
Arc::new(StringVector::from(vec!["t1", "t2"])) as _,
|
||||
Arc::new(StringVector::from(vec!["UInt32", "Timestamp"])) as _,
|
||||
Arc::new(StringVector::from(vec!["UInt32", "TimestampMillisecond"])) as _,
|
||||
Arc::new(StringVector::from(vec![NULLABLE_YES, NULLABLE_NO])) as _,
|
||||
Arc::new(StringVector::from(vec!["", "current_timestamp()"])) as _,
|
||||
Arc::new(StringVector::from(vec![
|
||||
|
||||
@@ -12,16 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
mod function;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use datatypes::types::WrapperType;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
@@ -29,7 +29,7 @@ use session::context::QueryContext;
|
||||
#[tokio::test]
|
||||
async fn test_argmax_aggregator() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let engine = create_query_engine();
|
||||
let engine = function::create_query_engine();
|
||||
|
||||
macro_rules! test_argmax {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
@@ -49,33 +49,23 @@ async fn test_argmax_success<T>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement + PartialOrd,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
T: WrapperType + PartialOrd,
|
||||
{
|
||||
let result = execute_argmax(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!("argmax", result[0].schema.arrow_schema().field(0).name());
|
||||
let value = function::get_value_from_batches("argmax", result);
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let value = v.get(0);
|
||||
|
||||
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let numbers =
|
||||
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let expected_value = match numbers.len() {
|
||||
0 => 0_u64,
|
||||
_ => {
|
||||
let mut index = 0;
|
||||
let mut max = numbers[0].into();
|
||||
let mut max = numbers[0];
|
||||
for (i, &number) in numbers.iter().enumerate() {
|
||||
if max < number.into() {
|
||||
max = number.into();
|
||||
if max < number {
|
||||
max = number;
|
||||
index = i;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,17 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
mod function;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use datatypes::types::WrapperType;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
@@ -30,7 +29,7 @@ use session::context::QueryContext;
|
||||
#[tokio::test]
|
||||
async fn test_argmin_aggregator() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let engine = create_query_engine();
|
||||
let engine = function::create_query_engine();
|
||||
|
||||
macro_rules! test_argmin {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
@@ -50,33 +49,23 @@ async fn test_argmin_success<T>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement + PartialOrd,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
T: WrapperType + PartialOrd,
|
||||
{
|
||||
let result = execute_argmin(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!("argmin", result[0].schema.arrow_schema().field(0).name());
|
||||
let value = function::get_value_from_batches("argmin", result);
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let value = v.get(0);
|
||||
|
||||
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let numbers =
|
||||
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let expected_value = match numbers.len() {
|
||||
0 => 0_u32,
|
||||
_ => {
|
||||
let mut index = 0;
|
||||
let mut min = numbers[0].into();
|
||||
let mut min = numbers[0];
|
||||
for (i, &number) in numbers.iter().enumerate() {
|
||||
if min > number.into() {
|
||||
min = number.into();
|
||||
if min > number {
|
||||
min = number;
|
||||
index = i;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// FIXME(yingwen): Consider move all tests under query/tests to query/src so we could reuse
|
||||
// more code.
|
||||
use std::sync::Arc;
|
||||
|
||||
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
|
||||
@@ -22,8 +24,8 @@ use common_recordbatch::{util, RecordBatch};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use datatypes::vectors::PrimitiveVector;
|
||||
use datatypes::types::WrapperType;
|
||||
use datatypes::vectors::Helper;
|
||||
use query::query_engine::QueryEngineFactory;
|
||||
use query::QueryEngine;
|
||||
use rand::Rng;
|
||||
@@ -47,7 +49,7 @@ pub fn create_query_engine() -> Arc<dyn QueryEngine> {
|
||||
column_schemas.push(column_schema);
|
||||
|
||||
let numbers = (1..=10).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
|
||||
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
|
||||
let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec()));
|
||||
columns.push(column);
|
||||
)*
|
||||
}
|
||||
@@ -77,8 +79,7 @@ pub async fn get_numbers_from_table<'s, T>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Vec<T>
|
||||
where
|
||||
T: PrimitiveElement,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
T: WrapperType,
|
||||
{
|
||||
let sql = format!("SELECT {} FROM {}", column_name, table_name);
|
||||
let plan = engine
|
||||
@@ -92,8 +93,21 @@ where
|
||||
};
|
||||
let numbers = util::collect(recordbatch_stream).await.unwrap();
|
||||
|
||||
let columns = numbers[0].df_recordbatch.columns();
|
||||
let column = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&column) };
|
||||
let column = numbers[0].column(0);
|
||||
let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(column) };
|
||||
column.iter_data().flatten().collect::<Vec<T>>()
|
||||
}
|
||||
|
||||
/// Extracts the single value held by a one-batch, one-column, one-row result.
///
/// Returns the value at row 0 of column 0 of `batches[0]`.
///
/// # Panics
///
/// Panics (via `assert_eq!`) unless `batches` contains exactly one batch,
/// that batch and its schema have exactly one column, the column's schema
/// name equals `column_name`, and the column has exactly one row.
pub fn get_value_from_batches(column_name: &str, batches: Vec<RecordBatch>) -> Value {
|
||||
assert_eq!(1, batches.len());
|
||||
assert_eq!(batches[0].num_columns(), 1);
|
||||
assert_eq!(1, batches[0].schema.num_columns());
|
||||
assert_eq!(column_name, batches[0].schema.column_schemas()[0].name);
|
||||
|
||||
let batch = &batches[0];
|
||||
assert_eq!(1, batch.num_columns());
|
||||
assert_eq!(batch.column(0).len(), 1);
|
||||
let v = batch.column(0);
|
||||
assert_eq!(1, v.len());
|
||||
// The checks above guarantee that index 0 exists.
v.get(0)
|
||||
}
|
||||
|
||||
@@ -12,19 +12,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
mod function;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use datatypes::types::WrapperType;
|
||||
use datatypes::value::OrderedFloat;
|
||||
use format_num::NumberFormat;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
@@ -33,7 +32,7 @@ use session::context::QueryContext;
|
||||
#[tokio::test]
|
||||
async fn test_mean_aggregator() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let engine = create_query_engine();
|
||||
let engine = function::create_query_engine();
|
||||
|
||||
macro_rules! test_mean {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
@@ -53,25 +52,15 @@ async fn test_mean_success<T>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement + AsPrimitive<f64>,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
T: WrapperType + AsPrimitive<f64>,
|
||||
{
|
||||
let result = execute_mean(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!("mean", result[0].schema.arrow_schema().field(0).name());
|
||||
let value = function::get_value_from_batches("mean", result);
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let value = v.get(0);
|
||||
|
||||
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let numbers =
|
||||
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
|
||||
|
||||
let expected_value = inc_stats::mean(expected_value.iter().cloned()).unwrap();
|
||||
|
||||
@@ -26,12 +26,10 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
|
||||
use common_query::prelude::*;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::arrow_print;
|
||||
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::types::{PrimitiveElement, PrimitiveType};
|
||||
use datatypes::vectors::PrimitiveVector;
|
||||
use datatypes::types::{LogicalPrimitiveType, WrapperType};
|
||||
use datatypes::vectors::Helper;
|
||||
use datatypes::with_match_primitive_type_id;
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
@@ -40,28 +38,30 @@ use session::context::QueryContext;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct MySumAccumulator<T, SumT>
|
||||
where
|
||||
T: Primitive + AsPrimitive<SumT>,
|
||||
SumT: Primitive + std::ops::AddAssign,
|
||||
{
|
||||
struct MySumAccumulator<T, SumT> {
|
||||
sum: SumT,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T, SumT> MySumAccumulator<T, SumT>
|
||||
where
|
||||
T: Primitive + AsPrimitive<SumT>,
|
||||
SumT: Primitive + std::ops::AddAssign,
|
||||
T: WrapperType,
|
||||
SumT: WrapperType,
|
||||
T::Native: AsPrimitive<SumT::Native>,
|
||||
SumT::Native: std::ops::AddAssign,
|
||||
{
|
||||
#[inline(always)]
|
||||
fn add(&mut self, v: T) {
|
||||
self.sum += v.as_();
|
||||
let mut sum_native = self.sum.into_native();
|
||||
sum_native += v.into_native().as_();
|
||||
self.sum = SumT::from_native(sum_native);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn merge(&mut self, s: SumT) {
|
||||
self.sum += s;
|
||||
let mut sum_native = self.sum.into_native();
|
||||
sum_native += s.into_native();
|
||||
self.sum = SumT::from_native(sum_native);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
|
||||
with_match_primitive_type_id!(
|
||||
input_type.logical_type_id(),
|
||||
|$S| {
|
||||
Ok(Box::new(MySumAccumulator::<$S, <$S as Primitive>::LargestType>::default()))
|
||||
Ok(Box::new(MySumAccumulator::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default()))
|
||||
},
|
||||
{
|
||||
let err_msg = format!(
|
||||
@@ -95,7 +95,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
|
||||
with_match_primitive_type_id!(
|
||||
input_type.logical_type_id(),
|
||||
|$S| {
|
||||
Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().logical_type_id().data_type())
|
||||
Ok(<<$S as LogicalPrimitiveType>::LargestType>::build_data_type())
|
||||
},
|
||||
{
|
||||
unreachable!()
|
||||
@@ -110,10 +110,10 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
|
||||
|
||||
impl<T, SumT> Accumulator for MySumAccumulator<T, SumT>
|
||||
where
|
||||
T: Primitive + AsPrimitive<SumT>,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
SumT: Primitive + std::ops::AddAssign,
|
||||
for<'a> SumT: Scalar<RefType<'a> = SumT>,
|
||||
T: WrapperType,
|
||||
SumT: WrapperType,
|
||||
T::Native: AsPrimitive<SumT::Native>,
|
||||
SumT::Native: std::ops::AddAssign,
|
||||
{
|
||||
fn state(&self) -> QueryResult<Vec<Value>> {
|
||||
Ok(vec![self.sum.into()])
|
||||
@@ -124,7 +124,7 @@ where
|
||||
return Ok(());
|
||||
};
|
||||
let column = &values[0];
|
||||
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(column) };
|
||||
let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(column) };
|
||||
for v in column.iter_data().flatten() {
|
||||
self.add(v)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ where
|
||||
return Ok(());
|
||||
};
|
||||
let states = &states[0];
|
||||
let states: &<SumT as Scalar>::VectorType = unsafe { VectorHelper::static_cast(states) };
|
||||
let states: &<SumT as Scalar>::VectorType = unsafe { Helper::static_cast(states) };
|
||||
for s in states.iter_data().flatten() {
|
||||
self.merge(s)
|
||||
}
|
||||
@@ -154,65 +154,57 @@ async fn test_my_sum() -> Result<()> {
|
||||
|
||||
test_my_sum_with(
|
||||
(1..=10).collect::<Vec<u32>>(),
|
||||
vec![
|
||||
"+--------+",
|
||||
"| my_sum |",
|
||||
"+--------+",
|
||||
"| 55 |",
|
||||
"+--------+",
|
||||
],
|
||||
r#"+--------+
|
||||
| my_sum |
|
||||
+--------+
|
||||
| 55 |
|
||||
+--------+"#,
|
||||
)
|
||||
.await?;
|
||||
test_my_sum_with(
|
||||
(-10..=11).collect::<Vec<i32>>(),
|
||||
vec![
|
||||
"+--------+",
|
||||
"| my_sum |",
|
||||
"+--------+",
|
||||
"| 11 |",
|
||||
"+--------+",
|
||||
],
|
||||
r#"+--------+
|
||||
| my_sum |
|
||||
+--------+
|
||||
| 11 |
|
||||
+--------+"#,
|
||||
)
|
||||
.await?;
|
||||
test_my_sum_with(
|
||||
vec![-1.0f32, 1.0, 2.0, 3.0, 4.0],
|
||||
vec![
|
||||
"+--------+",
|
||||
"| my_sum |",
|
||||
"+--------+",
|
||||
"| 9 |",
|
||||
"+--------+",
|
||||
],
|
||||
r#"+--------+
|
||||
| my_sum |
|
||||
+--------+
|
||||
| 9 |
|
||||
+--------+"#,
|
||||
)
|
||||
.await?;
|
||||
test_my_sum_with(
|
||||
vec![u32::MAX, u32::MAX],
|
||||
vec![
|
||||
"+------------+",
|
||||
"| my_sum |",
|
||||
"+------------+",
|
||||
"| 8589934590 |",
|
||||
"+------------+",
|
||||
],
|
||||
r#"+------------+
|
||||
| my_sum |
|
||||
+------------+
|
||||
| 8589934590 |
|
||||
+------------+"#,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_my_sum_with<T>(numbers: Vec<T>, expected: Vec<&str>) -> Result<()>
|
||||
async fn test_my_sum_with<T>(numbers: Vec<T>, expected: &str) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement,
|
||||
T: WrapperType,
|
||||
{
|
||||
let table_name = format!("{}_numbers", std::any::type_name::<T>());
|
||||
let column_name = format!("{}_number", std::any::type_name::<T>());
|
||||
|
||||
let column_schemas = vec![ColumnSchema::new(
|
||||
column_name.clone(),
|
||||
T::build_data_type(),
|
||||
T::LogicalType::build_data_type(),
|
||||
true,
|
||||
)];
|
||||
let schema = Arc::new(Schema::new(column_schemas.clone()));
|
||||
let column: VectorRef = Arc::new(PrimitiveVector::<T>::from_vec(numbers));
|
||||
let column: VectorRef = Arc::new(T::VectorType::from_vec(numbers));
|
||||
let recordbatch = RecordBatch::new(schema, vec![column]).unwrap();
|
||||
let testing_table = MemTable::new(&table_name, recordbatch);
|
||||
|
||||
@@ -236,14 +228,9 @@ where
|
||||
Output::Stream(batch) => batch,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let recordbatch = util::collect(recordbatch_stream).await.unwrap();
|
||||
let df_recordbatch = recordbatch
|
||||
.into_iter()
|
||||
.map(|r| r.df_recordbatch)
|
||||
.collect::<Vec<DfRecordBatch>>();
|
||||
let batches = util::collect_batches(recordbatch_stream).await.unwrap();
|
||||
|
||||
let pretty_print = arrow_print::write(&df_recordbatch);
|
||||
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
|
||||
let pretty_print = batches.pretty_print().unwrap();
|
||||
assert_eq!(expected, pretty_print);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -20,12 +20,10 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use datatypes::vectors::PrimitiveVector;
|
||||
use datatypes::vectors::Int32Vector;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
@@ -64,9 +62,8 @@ async fn test_percentile_correctness() -> Result<()> {
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let record_batch = util::collect(recordbatch_stream).await.unwrap();
|
||||
let columns = record_batch[0].df_recordbatch.columns();
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
let value = v.get(0);
|
||||
let column = record_batch[0].column(0);
|
||||
let value = column.get(0);
|
||||
assert_eq!(value, Value::from(9.280_000_000_000_001_f64));
|
||||
Ok(())
|
||||
}
|
||||
@@ -77,26 +74,12 @@ async fn test_percentile_success<T>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement + AsPrimitive<f64>,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
T: WrapperType + AsPrimitive<f64>,
|
||||
{
|
||||
let result = execute_percentile(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!(
|
||||
"percentile",
|
||||
result[0].schema.arrow_schema().field(0).name()
|
||||
);
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let value = v.get(0);
|
||||
let value = function::get_value_from_batches("percentile", result);
|
||||
|
||||
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
|
||||
@@ -140,9 +123,9 @@ fn create_correctness_engine() -> Arc<dyn QueryEngine> {
|
||||
let column_schema = ColumnSchema::new("corr_number", ConcreteDataType::int32_datatype(), true);
|
||||
column_schemas.push(column_schema);
|
||||
|
||||
let numbers = vec![3_i32, 6_i32, 8_i32, 10_i32];
|
||||
let numbers = [3_i32, 6_i32, 8_i32, 10_i32];
|
||||
|
||||
let column: VectorRef = Arc::new(PrimitiveVector::<i32>::from_vec(numbers.to_vec()));
|
||||
let column: VectorRef = Arc::new(Int32Vector::from_slice(&numbers));
|
||||
columns.push(column);
|
||||
|
||||
let schema = Arc::new(Schema::new(column_schemas));
|
||||
|
||||
@@ -18,11 +18,9 @@ mod function;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use datatypes::types::WrapperType;
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
@@ -31,13 +29,13 @@ use session::context::QueryContext;
|
||||
#[tokio::test]
|
||||
async fn test_polyval_aggregator() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let engine = create_query_engine();
|
||||
let engine = function::create_query_engine();
|
||||
|
||||
macro_rules! test_polyval {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
$(
|
||||
let column_name = format!("{}_number", std::any::type_name::<$T>());
|
||||
test_polyval_success::<$T,<$T as Primitive>::LargestType>(&column_name, "numbers", engine.clone()).await?;
|
||||
test_polyval_success::<$T, <<<$T as WrapperType>::LogicalType as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>(&column_name, "numbers", engine.clone()).await?;
|
||||
)*
|
||||
}
|
||||
}
|
||||
@@ -51,36 +49,27 @@ async fn test_polyval_success<T, PolyT>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: Primitive + AsPrimitive<PolyT> + PrimitiveElement,
|
||||
PolyT: Primitive + std::ops::Mul<Output = PolyT> + std::iter::Sum,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
for<'a> PolyT: Scalar<RefType<'a> = PolyT>,
|
||||
i64: AsPrimitive<PolyT>,
|
||||
T: WrapperType,
|
||||
PolyT: WrapperType,
|
||||
T::Native: AsPrimitive<PolyT::Native>,
|
||||
PolyT::Native: std::ops::Mul<Output = PolyT::Native> + std::iter::Sum,
|
||||
i64: AsPrimitive<PolyT::Native>,
|
||||
{
|
||||
let result = execute_polyval(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!("polyval", result[0].schema.arrow_schema().field(0).name());
|
||||
let value = function::get_value_from_batches("polyval", result);
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let value = v.get(0);
|
||||
|
||||
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let numbers =
|
||||
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let expected_value = numbers.iter().copied();
|
||||
let x = 0i64;
|
||||
let len = expected_value.len();
|
||||
let expected_value: PolyT = expected_value
|
||||
let expected_native: PolyT::Native = expected_value
|
||||
.enumerate()
|
||||
.map(|(i, value)| value.as_() * (x.pow((len - 1 - i) as u32)).as_())
|
||||
.map(|(i, v)| v.into_native().as_() * (x.pow((len - 1 - i) as u32)).as_())
|
||||
.sum();
|
||||
assert_eq!(value, expected_value.into());
|
||||
assert_eq!(value, PolyT::from_native(expected_native).into());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ pub fn pow(args: &[VectorRef]) -> Result<VectorRef> {
|
||||
|
||||
assert_eq!(exponent.len(), base.len());
|
||||
|
||||
let v = base
|
||||
let iter = base
|
||||
.iter_data()
|
||||
.zip(exponent.iter_data())
|
||||
.map(|(base, exponent)| {
|
||||
@@ -42,8 +42,8 @@ pub fn pow(args: &[VectorRef]) -> Result<VectorRef> {
|
||||
(Some(base), Some(exponent)) => Some(base.pow(exponent)),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
.collect::<UInt32Vector>();
|
||||
});
|
||||
let v = UInt32Vector::from_owned_iterator(iter);
|
||||
|
||||
Ok(Arc::new(v) as _)
|
||||
}
|
||||
|
||||
@@ -13,30 +13,28 @@
|
||||
// limitations under the License.
|
||||
|
||||
mod pow;
|
||||
// This is used to suppress the warning: function `create_query_engine` is never used.
|
||||
// FIXME(yingwen): We finally need to refactor these tests and move them to `query/src`
|
||||
// so tests can share code with other mods.
|
||||
#[allow(unused)]
|
||||
mod function;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::{CatalogList, CatalogProvider, SchemaProvider};
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_query::prelude::{create_udf, make_scalar_function, Volatility};
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datafusion::logical_plan::LogicalPlanBuilder;
|
||||
use datatypes::arrow::array::UInt32Array;
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datafusion::datasource::DefaultTableSource;
|
||||
use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::types::{OrdPrimitive, PrimitiveElement};
|
||||
use datatypes::vectors::{PrimitiveVector, UInt32Vector};
|
||||
use num::NumCast;
|
||||
use datatypes::vectors::UInt32Vector;
|
||||
use query::error::Result;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::QueryEngineFactory;
|
||||
use query::QueryEngine;
|
||||
use rand::Rng;
|
||||
use session::context::QueryContext;
|
||||
use table::table::adapter::DfTableProviderAdapter;
|
||||
use table::table::numbers::NumbersTable;
|
||||
@@ -66,12 +64,16 @@ async fn test_datafusion_query_engine() -> Result<()> {
|
||||
let limit = 10;
|
||||
let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone()));
|
||||
let plan = LogicalPlan::DfPlan(
|
||||
LogicalPlanBuilder::scan("numbers", table_provider, None)
|
||||
.unwrap()
|
||||
.limit(limit)
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap(),
|
||||
LogicalPlanBuilder::scan(
|
||||
"numbers",
|
||||
Arc::new(DefaultTableSource { table_provider }),
|
||||
None,
|
||||
)
|
||||
.unwrap()
|
||||
.limit(0, Some(limit))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let output = engine.execute(&plan).await?;
|
||||
@@ -84,17 +86,17 @@ async fn test_datafusion_query_engine() -> Result<()> {
|
||||
let numbers = util::collect(recordbatch).await.unwrap();
|
||||
|
||||
assert_eq!(1, numbers.len());
|
||||
assert_eq!(numbers[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, numbers[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!("number", numbers[0].schema.arrow_schema().field(0).name());
|
||||
assert_eq!(numbers[0].num_columns(), 1);
|
||||
assert_eq!(1, numbers[0].schema.num_columns());
|
||||
assert_eq!("number", numbers[0].schema.column_schemas()[0].name);
|
||||
|
||||
let columns = numbers[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), limit);
|
||||
let batch = &numbers[0];
|
||||
assert_eq!(1, batch.num_columns());
|
||||
assert_eq!(batch.column(0).len(), limit);
|
||||
let expected: Vec<u32> = (0u32..limit as u32).collect();
|
||||
assert_eq!(
|
||||
*columns[0].as_any().downcast_ref::<UInt32Array>().unwrap(),
|
||||
UInt32Array::from_slice(&expected)
|
||||
*batch.column(0),
|
||||
Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef
|
||||
);
|
||||
|
||||
Ok(())
|
||||
@@ -123,7 +125,8 @@ async fn test_udf() -> Result<()> {
|
||||
let pow = make_scalar_function(pow);
|
||||
|
||||
let udf = create_udf(
|
||||
"pow",
|
||||
// datafusion already supports pow, so we use a different name.
|
||||
"my_pow",
|
||||
vec![
|
||||
ConcreteDataType::uint32_datatype(),
|
||||
ConcreteDataType::uint32_datatype(),
|
||||
@@ -136,7 +139,7 @@ async fn test_udf() -> Result<()> {
|
||||
engine.register_udf(udf);
|
||||
|
||||
let plan = engine.sql_to_plan(
|
||||
"select pow(number, number) as p from numbers limit 10",
|
||||
"select my_pow(number, number) as p from numbers limit 10",
|
||||
Arc::new(QueryContext::new()),
|
||||
)?;
|
||||
|
||||
@@ -148,202 +151,18 @@ async fn test_udf() -> Result<()> {
|
||||
|
||||
let numbers = util::collect(recordbatch).await.unwrap();
|
||||
assert_eq!(1, numbers.len());
|
||||
assert_eq!(numbers[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, numbers[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!("p", numbers[0].schema.arrow_schema().field(0).name());
|
||||
assert_eq!(numbers[0].num_columns(), 1);
|
||||
assert_eq!(1, numbers[0].schema.num_columns());
|
||||
assert_eq!("p", numbers[0].schema.column_schemas()[0].name);
|
||||
|
||||
let columns = numbers[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 10);
|
||||
let batch = &numbers[0];
|
||||
assert_eq!(1, batch.num_columns());
|
||||
assert_eq!(batch.column(0).len(), 10);
|
||||
let expected: Vec<u32> = vec![1, 1, 4, 27, 256, 3125, 46656, 823543, 16777216, 387420489];
|
||||
assert_eq!(
|
||||
*columns[0].as_any().downcast_ref::<UInt32Array>().unwrap(),
|
||||
UInt32Array::from_slice(&expected)
|
||||
*batch.column(0),
|
||||
Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_query_engine() -> Arc<dyn QueryEngine> {
|
||||
let schema_provider = Arc::new(MemorySchemaProvider::new());
|
||||
let catalog_provider = Arc::new(MemoryCatalogProvider::new());
|
||||
let catalog_list = Arc::new(MemoryCatalogManager::default());
|
||||
|
||||
// create table with primitives, and all columns' length are even
|
||||
let mut column_schemas = vec![];
|
||||
let mut columns = vec![];
|
||||
macro_rules! create_even_number_table {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
$(
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let column_name = format!("{}_number_even", std::any::type_name::<$T>());
|
||||
let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
|
||||
column_schemas.push(column_schema);
|
||||
|
||||
let numbers = (1..=100).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
|
||||
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
|
||||
columns.push(column);
|
||||
)*
|
||||
}
|
||||
}
|
||||
for_all_primitive_types! { create_even_number_table }
|
||||
|
||||
let schema = Arc::new(Schema::new(column_schemas.clone()));
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let even_number_table = Arc::new(MemTable::new("even_numbers", recordbatch));
|
||||
schema_provider
|
||||
.register_table(
|
||||
even_number_table.table_name().to_string(),
|
||||
even_number_table,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// create table with primitives, and all columns' length are odd
|
||||
let mut column_schemas = vec![];
|
||||
let mut columns = vec![];
|
||||
macro_rules! create_odd_number_table {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
$(
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
|
||||
let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
|
||||
column_schemas.push(column_schema);
|
||||
|
||||
let numbers = (1..=99).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
|
||||
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
|
||||
columns.push(column);
|
||||
)*
|
||||
}
|
||||
}
|
||||
for_all_primitive_types! { create_odd_number_table }
|
||||
|
||||
let schema = Arc::new(Schema::new(column_schemas.clone()));
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let odd_number_table = Arc::new(MemTable::new("odd_numbers", recordbatch));
|
||||
schema_provider
|
||||
.register_table(odd_number_table.table_name().to_string(), odd_number_table)
|
||||
.unwrap();
|
||||
|
||||
catalog_provider
|
||||
.register_schema(DEFAULT_SCHEMA_NAME.to_string(), schema_provider)
|
||||
.unwrap();
|
||||
catalog_list
|
||||
.register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider)
|
||||
.unwrap();
|
||||
|
||||
QueryEngineFactory::new(catalog_list).query_engine()
|
||||
}
|
||||
|
||||
async fn get_numbers_from_table<'s, T>(
|
||||
column_name: &'s str,
|
||||
table_name: &'s str,
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Vec<OrdPrimitive<T>>
|
||||
where
|
||||
T: PrimitiveElement,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
{
|
||||
let sql = format!("SELECT {} FROM {}", column_name, table_name);
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
Output::Stream(batch) => batch,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let numbers = util::collect(recordbatch_stream).await.unwrap();
|
||||
|
||||
let columns = numbers[0].df_recordbatch.columns();
|
||||
let column = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&column) };
|
||||
column
|
||||
.iter_data()
|
||||
.flatten()
|
||||
.map(|x| OrdPrimitive::<T>(x))
|
||||
.collect::<Vec<OrdPrimitive<T>>>()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_median_aggregator() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let engine = create_query_engine();
|
||||
|
||||
macro_rules! test_median {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
$(
|
||||
let column_name = format!("{}_number_even", std::any::type_name::<$T>());
|
||||
test_median_success::<$T>(&column_name, "even_numbers", engine.clone()).await?;
|
||||
|
||||
let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
|
||||
test_median_success::<$T>(&column_name, "odd_numbers", engine.clone()).await?;
|
||||
)*
|
||||
}
|
||||
}
|
||||
for_all_primitive_types! { test_median }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_median_success<T>(
|
||||
column_name: &str,
|
||||
table_name: &str,
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
{
|
||||
let result = execute_median(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!("median", result[0].schema.arrow_schema().field(0).name());
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let median = v.get(0);
|
||||
|
||||
let mut numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
numbers.sort();
|
||||
let len = numbers.len();
|
||||
let expected_median: Value = if len % 2 == 1 {
|
||||
numbers[len / 2]
|
||||
} else {
|
||||
let a: f64 = NumCast::from(numbers[len / 2 - 1].as_primitive()).unwrap();
|
||||
let b: f64 = NumCast::from(numbers[len / 2].as_primitive()).unwrap();
|
||||
OrdPrimitive::<T>(NumCast::from(a / 2.0 + b / 2.0).unwrap())
|
||||
}
|
||||
.into();
|
||||
assert_eq!(expected_median, median);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute_median<'a>(
|
||||
column_name: &'a str,
|
||||
table_name: &'a str,
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> RecordResult<Vec<RecordBatch>> {
|
||||
let sql = format!(
|
||||
"select MEDIAN({}) as median from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
Output::Stream(batch) => batch,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
util::collect(recordbatch_stream).await
|
||||
}
|
||||
|
||||
@@ -18,11 +18,8 @@ mod function;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use datatypes::types::WrapperType;
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
@@ -33,7 +30,7 @@ use statrs::statistics::Statistics;
|
||||
#[tokio::test]
|
||||
async fn test_scipy_stats_norm_cdf_aggregator() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let engine = create_query_engine();
|
||||
let engine = function::create_query_engine();
|
||||
|
||||
macro_rules! test_scipy_stats_norm_cdf {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
@@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_cdf_success<T>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement + AsPrimitive<f64>,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
T: WrapperType + AsPrimitive<f64>,
|
||||
{
|
||||
let result = execute_scipy_stats_norm_cdf(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!(
|
||||
"scipy_stats_norm_cdf",
|
||||
result[0].schema.arrow_schema().field(0).name()
|
||||
);
|
||||
let value = function::get_value_from_batches("scipy_stats_norm_cdf", result);
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let value = v.get(0);
|
||||
|
||||
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let numbers =
|
||||
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
|
||||
let mean = expected_value.clone().mean();
|
||||
let stddev = expected_value.std_dev();
|
||||
|
||||
@@ -18,11 +18,8 @@ mod function;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordResult;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::for_all_primitive_types;
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::types::PrimitiveElement;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use datatypes::types::WrapperType;
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
@@ -33,7 +30,7 @@ use statrs::statistics::Statistics;
|
||||
#[tokio::test]
|
||||
async fn test_scipy_stats_norm_pdf_aggregator() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let engine = create_query_engine();
|
||||
let engine = function::create_query_engine();
|
||||
|
||||
macro_rules! test_scipy_stats_norm_pdf {
|
||||
([], $( { $T:ty } ),*) => {
|
||||
@@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_pdf_success<T>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: PrimitiveElement + AsPrimitive<f64>,
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
T: WrapperType + AsPrimitive<f64>,
|
||||
{
|
||||
let result = execute_scipy_stats_norm_pdf(column_name, table_name, engine.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(1, result.len());
|
||||
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
|
||||
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
|
||||
assert_eq!(
|
||||
"scipy_stats_norm_pdf",
|
||||
result[0].schema.arrow_schema().field(0).name()
|
||||
);
|
||||
let value = function::get_value_from_batches("scipy_stats_norm_pdf", result);
|
||||
|
||||
let columns = result[0].df_recordbatch.columns();
|
||||
assert_eq!(1, columns.len());
|
||||
assert_eq!(columns[0].len(), 1);
|
||||
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
|
||||
assert_eq!(1, v.len());
|
||||
let value = v.get(0);
|
||||
|
||||
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let numbers =
|
||||
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
|
||||
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
|
||||
let mean = expected_value.clone().mean();
|
||||
let stddev = expected_value.std_dev();
|
||||
|
||||
@@ -15,4 +15,4 @@ itertools = "0.10"
|
||||
mito = { path = "../mito" }
|
||||
once_cell = "1.10"
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
sqlparser = "0.15.0"
|
||||
sqlparser = "0.26"
|
||||
|
||||
@@ -14,5 +14,5 @@
|
||||
|
||||
pub use sqlparser::ast::{
|
||||
ColumnDef, ColumnOption, ColumnOptionDef, DataType, Expr, Function, FunctionArg,
|
||||
FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, Value,
|
||||
FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, TimezoneInfo, Value,
|
||||
};
|
||||
|
||||
@@ -505,11 +505,7 @@ mod tests {
|
||||
assert_matches!(
|
||||
&stmts[0],
|
||||
Statement::ShowTables(ShowTables {
|
||||
kind: ShowKind::Where(sqlparser::ast::Expr::BinaryOp {
|
||||
left: _,
|
||||
right: _,
|
||||
op: sqlparser::ast::BinaryOperator::Like,
|
||||
}),
|
||||
kind: ShowKind::Where(sqlparser::ast::Expr::Like { .. }),
|
||||
database: None,
|
||||
})
|
||||
);
|
||||
@@ -522,11 +518,7 @@ mod tests {
|
||||
assert_matches!(
|
||||
&stmts[0],
|
||||
Statement::ShowTables(ShowTables {
|
||||
kind: ShowKind::Where(sqlparser::ast::Expr::BinaryOp {
|
||||
left: _,
|
||||
right: _,
|
||||
op: sqlparser::ast::BinaryOperator::Like,
|
||||
}),
|
||||
kind: ShowKind::Where(sqlparser::ast::Expr::Like { .. }),
|
||||
database: Some(_),
|
||||
})
|
||||
);
|
||||
@@ -543,11 +535,12 @@ mod tests {
|
||||
distinct: false,
|
||||
top: None,
|
||||
projection: vec![sqlparser::ast::SelectItem::Wildcard],
|
||||
into: None,
|
||||
from: vec![sqlparser::ast::TableWithJoins {
|
||||
relation: sqlparser::ast::TableFactor::Table {
|
||||
name: sqlparser::ast::ObjectName(vec![sqlparser::ast::Ident::new("foo")]),
|
||||
alias: None,
|
||||
args: vec![],
|
||||
args: None,
|
||||
with_hints: vec![],
|
||||
},
|
||||
joins: vec![],
|
||||
@@ -559,11 +552,12 @@ mod tests {
|
||||
distribute_by: vec![],
|
||||
sort_by: vec![],
|
||||
having: None,
|
||||
qualify: None,
|
||||
};
|
||||
|
||||
let sp_statement = SpStatement::Query(Box::new(SpQuery {
|
||||
with: None,
|
||||
body: sqlparser::ast::SetExpr::Select(Box::new(select)),
|
||||
body: Box::new(sqlparser::ast::SetExpr::Select(Box::new(select))),
|
||||
order_by: vec![],
|
||||
limit: None,
|
||||
offset: None,
|
||||
@@ -576,6 +570,7 @@ mod tests {
|
||||
analyze: false,
|
||||
verbose: false,
|
||||
statement: Box::new(sp_statement),
|
||||
format: None,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -318,7 +318,7 @@ pub fn sql_data_type_to_concrete_data_type(data_type: &SqlDataType) -> Result<Co
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
SqlDataType::Timestamp => Ok(ConcreteDataType::timestamp_millisecond_datatype()),
|
||||
SqlDataType::Timestamp(_) => Ok(ConcreteDataType::timestamp_millisecond_datatype()),
|
||||
_ => error::SqlTypeNotSupportedSnafu {
|
||||
t: data_type.clone(),
|
||||
}
|
||||
@@ -336,7 +336,7 @@ mod tests {
|
||||
use datatypes::value::OrderedFloat;
|
||||
|
||||
use super::*;
|
||||
use crate::ast::{DataType, Ident};
|
||||
use crate::ast::{Ident, TimezoneInfo};
|
||||
use crate::statements::ColumnOption;
|
||||
|
||||
fn check_type(sql_type: SqlDataType, data_type: ConcreteDataType) {
|
||||
@@ -376,7 +376,7 @@ mod tests {
|
||||
ConcreteDataType::datetime_datatype(),
|
||||
);
|
||||
check_type(
|
||||
SqlDataType::Timestamp,
|
||||
SqlDataType::Timestamp(TimezoneInfo::None),
|
||||
ConcreteDataType::timestamp_millisecond_datatype(),
|
||||
);
|
||||
}
|
||||
@@ -573,7 +573,7 @@ mod tests {
|
||||
// test basic
|
||||
let column_def = ColumnDef {
|
||||
name: "col".into(),
|
||||
data_type: DataType::Double,
|
||||
data_type: SqlDataType::Double,
|
||||
collation: None,
|
||||
options: vec![],
|
||||
};
|
||||
@@ -588,7 +588,7 @@ mod tests {
|
||||
// test not null
|
||||
let column_def = ColumnDef {
|
||||
name: "col".into(),
|
||||
data_type: DataType::Double,
|
||||
data_type: SqlDataType::Double,
|
||||
collation: None,
|
||||
options: vec![ColumnOptionDef {
|
||||
name: None,
|
||||
|
||||
@@ -49,7 +49,7 @@ impl Insert {
|
||||
|
||||
pub fn values(&self) -> Result<Vec<Vec<Value>>> {
|
||||
let values = match &self.inner {
|
||||
Statement::Insert { source, .. } => match &source.body {
|
||||
Statement::Insert { source, .. } => match &*source.body {
|
||||
SetExpr::Values(Values(exprs)) => sql_exprs_to_values(exprs)?,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user