fix: Fix compiler errors in query crate (#746)

* fix: Fix compiler errors in state.rs

* fix: fix compiler errors in state

* feat: upgrade sqlparser to 0.26

* fix: fix datafusion engine compiler errors

* fix: Fix some tests in query crate

* fix: Fix all warnings in tests

* feat: Remove `Type` from timestamp's type name

* fix: fix query tests

Now datafusion already supports median, so this commit also remove the
median function

* style: Fix clippy

* feat: Remove RecordBatch::pretty_print

* chore: Address CR comments

* Update src/query/src/query_engine/state.rs

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Yingwen
2022-12-14 17:42:07 +08:00
committed by GitHub
parent 652d59a643
commit dbb3034ecb
34 changed files with 349 additions and 888 deletions

22
Cargo.lock generated
View File

@@ -1899,7 +1899,7 @@ dependencies = [
"pin-project-lite",
"rand 0.8.5",
"smallvec",
"sqlparser 0.26.0",
"sqlparser",
"tempfile",
"tokio",
"tokio-stream",
@@ -1919,7 +1919,7 @@ dependencies = [
"object_store",
"ordered-float 3.4.0",
"parquet",
"sqlparser 0.26.0",
"sqlparser",
]
[[package]]
@@ -1932,7 +1932,7 @@ dependencies = [
"arrow",
"datafusion-common",
"log",
"sqlparser 0.26.0",
"sqlparser",
]
[[package]]
@@ -2003,7 +2003,7 @@ dependencies = [
"arrow",
"datafusion-common",
"datafusion-expr",
"sqlparser 0.26.0",
"sqlparser",
]
[[package]]
@@ -2467,7 +2467,6 @@ dependencies = [
"session",
"snafu",
"sql",
"sqlparser 0.15.0",
"store-api",
"substrait 0.1.0",
"table",
@@ -4950,7 +4949,9 @@ dependencies = [
"datafusion",
"datafusion-common",
"datafusion-expr",
"datafusion-optimizer",
"datafusion-physical-expr",
"datafusion-sql",
"datatypes",
"format_num",
"futures",
@@ -6357,7 +6358,7 @@ dependencies = [
"mito",
"once_cell",
"snafu",
"sqlparser 0.15.0",
"sqlparser",
]
[[package]]
@@ -6386,15 +6387,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "sqlparser"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adbbea2526ad0d02ad9414a07c396078a5b944bbf9ca4fbab8f01bb4cb579081"
dependencies = [
"log",
]
[[package]]
name = "sqlparser"
version = "0.26.0"

View File

@@ -23,6 +23,5 @@ pub(crate) mod test;
mod timestamp;
pub mod udf;
pub use aggregate::MedianAccumulatorCreator;
pub use function::{Function, FunctionRef};
pub use function_registry::{FunctionRegistry, FUNCTION_REGISTRY};

View File

@@ -16,7 +16,6 @@ mod argmax;
mod argmin;
mod diff;
mod mean;
mod median;
mod percentile;
mod polyval;
mod scipy_stats_norm_cdf;
@@ -29,7 +28,6 @@ pub use argmin::ArgminAccumulatorCreator;
use common_query::logical_plan::AggregateFunctionCreatorRef;
pub use diff::DiffAccumulatorCreator;
pub use mean::MeanAccumulatorCreator;
pub use median::MedianAccumulatorCreator;
pub use percentile::PercentileAccumulatorCreator;
pub use polyval::PolyvalAccumulatorCreator;
pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator;
@@ -88,7 +86,6 @@ impl AggregateFunctions {
};
}
register_aggr_func!("median", 1, MedianAccumulatorCreator);
register_aggr_func!("diff", 1, DiffAccumulatorCreator);
register_aggr_func!("mean", 1, MeanAccumulatorCreator);
register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator);

View File

@@ -1,287 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, Helper, ListVector};
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
// This median calculation algorithm's details can be found at
// https://leetcode.cn/problems/find-median-from-data-stream/
//
// Basically, it uses two heaps, a maximum heap and a minimum. The maximum heap stores numbers that
// are not greater than the median, and the minimum heap stores the greater. In a streaming of
// numbers, when a number is arrived, we adjust the heaps' tops, so that either one top is the
// median or both tops can be averaged to get the median.
//
// The time complexity to update the median is O(logn), O(1) to get the median; and the space
// complexity is O(n). (Ignore the costs for heap expansion.)
//
// From the point of algorithm, [quick select](https://en.wikipedia.org/wiki/Quickselect) might be
// better. But to use quick select here, we need a mutable self in the final calculation(`evaluate`)
// to swap stored numbers in the states vector. Though we can make our `evaluate` received
// `&mut self`, DataFusion calls our accumulator with `&self` (see `DfAccumulatorAdaptor`). That
// means we have to introduce some kinds of interior mutability, and the overhead is not neglectable.
//
// TODO(LFC): Use quick select to get median when we can modify DataFusion's code, and benchmark with two-heap algorithm.
#[derive(Debug, Default)]
pub struct Median<T>
where
T: WrapperType,
{
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>,
}
impl<T> Median<T>
where
T: WrapperType,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
if self.not_greater.is_empty() {
self.not_greater.push(value);
return;
}
// The `unwrap`s below are safe because there are `push`s before them.
if value <= *self.not_greater.peek().unwrap() {
self.not_greater.push(value);
if self.not_greater.len() > self.greater.len() + 1 {
self.greater.push(Reverse(self.not_greater.pop().unwrap()));
}
} else {
self.greater.push(Reverse(value));
if self.greater.len() > self.not_greater.len() {
self.not_greater.push(self.greater.pop().unwrap().0);
}
}
}
}
// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
// to use them.
impl<T> Accumulator for Median<T>
where
T: WrapperType + NumCast,
{
// This function serializes our state to `ScalarValue`, which DataFusion uses to pass this
// state between execution stages. Note that this can be arbitrary data.
//
// The `ScalarValue`s returned here will be passed in as argument `states: &[VectorRef]` to
// `merge_batch` function.
fn state(&self) -> Result<Vec<Value>> {
let nums = self
.greater
.iter()
.map(|x| &x.0)
.chain(self.not_greater.iter())
.map(|&n| n.into())
.collect::<Vec<Value>>();
Ok(vec![Value::List(ListValue::new(
Some(Box::new(nums)),
T::LogicalType::build_data_type(),
))])
}
// DataFusion calls this function to update the accumulator's state for a batch of inputs rows.
// It is expected this function to update the accumulator's state.
fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
ensure!(values.len() == 1, InvalidInputStateSnafu);
// This is a unary accumulator, so only one column is provided.
let column = &values[0];
let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len();
let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { Helper::static_cast(column.inner()) }
} else {
unsafe { Helper::static_cast(column) }
};
(0..len).for_each(|_| {
for v in column.iter_data().flatten() {
self.push(v);
}
});
Ok(())
}
// DataFusion executes accumulators in partitions. In some execution stage, DataFusion will
// merge states from other accumulators (returned by `state()` method).
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
// The states here are returned by the `state` method. Since we only returned a vector
// with one value in that method, `states[0]` is fine.
let states = &states[0];
let states = states
.as_any()
.downcast_ref::<ListVector>()
.with_context(|| DowncastVectorSnafu {
err_msg: format!(
"expect ListVector, got vector type {}",
states.vector_type_name()
),
})?;
for state in states.values_iter() {
if let Some(state) = state.context(FromScalarValueSnafu)? {
// merging state is simply accumulate stored numbers from others', so just call update
self.update_batch(&[state])?
}
}
Ok(())
}
// DataFusion expects this function to return the final value of this aggregator.
fn evaluate(&self) -> Result<Value> {
if self.not_greater.is_empty() {
assert!(
self.greater.is_empty(),
"not expected in two-heap median algorithm, there must be a bug when implementing it"
);
return Ok(Value::Null);
}
// unwrap is safe because we checked not_greater heap's len above
let not_greater = *self.not_greater.peek().unwrap();
let median = if self.not_greater.len() > self.greater.len() {
not_greater.into()
} else {
// unwrap is safe because greater heap len >= not_greater heap len, which is > 0
let greater = self.greater.peek().unwrap();
// the following three NumCast's `unwrap`s are safe because T is primitive
let not_greater_v: f64 = NumCast::from(not_greater.as_primitive()).unwrap();
let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap();
median.into()
};
Ok(median)
}
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct MedianAccumulatorCreator {}
impl AggregateFunctionCreator for MedianAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Median::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
},
{
let err_msg = format!(
"\"MEDIAN\" aggregate function not support data type {:?}",
input_type.logical_type_id(),
);
CreateAccumulatorSnafu { err_msg }.fail()?
}
)
});
creator
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);
// unwrap is safe because we have checked input_types len must equals 1
Ok(input_types.into_iter().next().unwrap())
}
fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
Ok(vec![ConcreteDataType::list_datatype(self.output_type()?)])
}
}
#[cfg(test)]
mod test {
use datatypes::vectors::Int32Vector;
use super::*;
#[test]
fn test_update_batch() {
// test update empty batch, expect not updating anything
let mut median = Median::<i32>::default();
assert!(median.update_batch(&[]).is_ok());
assert!(median.not_greater.is_empty());
assert!(median.greater.is_empty());
assert_eq!(Value::Null, median.evaluate().unwrap());
// test update one not-null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(42), median.evaluate().unwrap());
// test update one null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Null, median.evaluate().unwrap());
// test update no null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-1i32),
Some(1),
Some(2),
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(1), median.evaluate().unwrap());
// test update null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-2i32),
None,
Some(3),
Some(4),
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(3), median.evaluate().unwrap());
// test update with constant vector
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(Int32Vector::from_vec(vec![4])),
10,
))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(4), median.evaluate().unwrap());
}
}

View File

@@ -66,7 +66,7 @@ impl Stream for EmptyRecordBatchStream {
}
}
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
schema: SchemaRef,
batches: Vec<RecordBatch>,

View File

@@ -15,13 +15,20 @@
use futures::TryStreamExt;
use crate::error::Result;
use crate::{RecordBatch, SendableRecordBatchStream};
use crate::{RecordBatch, RecordBatches, SendableRecordBatchStream};
/// Collect all the items from the stream into a vector of [`RecordBatch`].
pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
stream.try_collect::<Vec<_>>().await
}
/// Collect all the items from the stream into [RecordBatches].
pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result<RecordBatches> {
let schema = stream.schema();
let batches = stream.try_collect::<Vec<_>>().await?;
RecordBatches::try_new(schema, batches)
}
#[cfg(test)]
mod tests {
use std::mem;
@@ -90,7 +97,14 @@ mod tests {
};
let batches = collect(Box::pin(stream)).await.unwrap();
assert_eq!(1, batches.len());
assert_eq!(batch, batches[0]);
let stream = MockRecordBatchStream {
schema: schema.clone(),
batch: Some(batch.clone()),
};
let batches = collect_batches(Box::pin(stream)).await.unwrap();
let expect_batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap();
assert_eq!(expect_batches, batches);
}
}

View File

@@ -70,7 +70,7 @@ macro_rules! impl_data_type_for_timestamp {
impl DataType for [<Timestamp $unit Type>] {
fn name(&self) -> &str {
stringify!([<Timestamp $unit Type>])
stringify!([<Timestamp $unit>])
}
fn logical_type_id(&self) -> LogicalTypeId {

View File

@@ -1346,7 +1346,7 @@ mod tests {
ConcreteDataType::timestamp_second_datatype(),
))
.to_string(),
"TimestampSecondType[]"
"TimestampSecond[]"
);
assert_eq!(
Value::List(ListValue::new(
@@ -1354,7 +1354,7 @@ mod tests {
ConcreteDataType::timestamp_millisecond_datatype(),
))
.to_string(),
"TimestampMillisecondType[]"
"TimestampMillisecond[]"
);
assert_eq!(
Value::List(ListValue::new(
@@ -1362,7 +1362,7 @@ mod tests {
ConcreteDataType::timestamp_microsecond_datatype(),
))
.to_string(),
"TimestampMicrosecondType[]"
"TimestampMicrosecond[]"
);
assert_eq!(
Value::List(ListValue::new(
@@ -1370,7 +1370,7 @@ mod tests {
ConcreteDataType::timestamp_nanosecond_datatype(),
))
.to_string(),
"TimestampNanosecondType[]"
"TimestampNanosecond[]"
);
}

View File

@@ -40,7 +40,6 @@ servers = { path = "../servers" }
session = { path = "../session" }
snafu = { version = "0.7", features = ["backtraces"] }
sql = { path = "../sql" }
sqlparser = "0.15"
store-api = { path = "../store-api" }
substrait = { path = "../common/substrait" }
table = { path = "../table" }

View File

@@ -19,9 +19,9 @@ use api::helper::ColumnDataTypeWrapper;
use api::v1::{Column, ColumnDataType, CreateExpr};
use datatypes::schema::ColumnSchema;
use snafu::{ensure, ResultExt};
use sql::ast::{ColumnDef, TableConstraint};
use sql::statements::create::{CreateTable, TIME_INDEX};
use sql::statements::{column_def_to_schema, table_idents_to_full_name};
use sqlparser::ast::{ColumnDef, TableConstraint};
use crate::error::{
BuildCreateExprOnInsertionSnafu, ColumnDataTypeSnafu, ConvertColumnDefaultConstraintSnafu,

View File

@@ -35,10 +35,10 @@ use query::sql::{describe_table, explain, show_databases, show_tables};
use query::{QueryEngineFactory, QueryEngineRef};
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::Value as SqlValue;
use sql::statements::create::Partitions;
use sql::statements::sql_value_to_value;
use sql::statements::statement::Statement;
use sqlparser::ast::Value as SqlValue;
use table::metadata::{RawTableInfo, RawTableMeta, TableIdent, TableType};
use crate::catalog::FrontendCatalogManager;
@@ -454,9 +454,9 @@ fn find_partition_columns(
#[cfg(test)]
mod test {
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory};

View File

@@ -530,9 +530,9 @@ mod test {
use meta_srv::mocks::MockInfo;
use meta_srv::service::store::kv::KvStoreRef;
use meta_srv::service::store::memory::MemStore;
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use sqlparser::dialect::GenericDialect;
use table::metadata::{TableInfoBuilder, TableMetaBuilder};
use table::TableRef;
use tempdir::TempDir;

View File

@@ -18,7 +18,9 @@ common-time = { path = "../common/time" }
datafusion = "14.0.0"
datafusion-common = "14.0.0"
datafusion-expr = "14.0.0"
datafusion-optimizer = "14.0.0"
datafusion-physical-expr = "14.0.0"
datafusion-sql = "14.0.0"
datatypes = { path = "../datatypes" }
futures = "0.3"
futures-util = "0.3"

View File

@@ -141,7 +141,6 @@ impl LogicalOptimizer for DatafusionQueryEngine {
LogicalPlan::DfPlan(df_plan) => {
let optimized_plan =
self.state
.df_context()
.optimize(df_plan)
.context(error::DatafusionSnafu {
msg: "Fail to optimize logical plan",
@@ -163,14 +162,11 @@ impl PhysicalPlanner for DatafusionQueryEngine {
let _timer = timer!(metric::METRIC_CREATE_PHYSICAL_ELAPSED);
match logical_plan {
LogicalPlan::DfPlan(df_plan) => {
let physical_plan = self
.state
.df_context()
.create_physical_plan(df_plan)
.await
.context(error::DatafusionSnafu {
let physical_plan = self.state.create_physical_plan(df_plan).await.context(
error::DatafusionSnafu {
msg: "Fail to create physical plan",
})?;
},
)?;
Ok(Arc::new(PhysicalPlanAdapter::new(
Arc::new(
@@ -193,22 +189,19 @@ impl PhysicalOptimizer for DatafusionQueryEngine {
plan: Arc<dyn PhysicalPlan>,
) -> Result<Arc<dyn PhysicalPlan>> {
let _timer = timer!(metric::METRIC_OPTIMIZE_PHYSICAL_ELAPSED);
let config = &self.state.df_context().state.lock().config;
let optimizers = &config.physical_optimizers;
let mut new_plan = plan
let new_plan = plan
.as_any()
.downcast_ref::<PhysicalPlanAdapter>()
.context(error::PhysicalPlanDowncastSnafu)?
.df_plan();
for optimizer in optimizers {
new_plan = optimizer
.optimize(new_plan, config)
let new_plan =
self.state
.optimize_physical_plan(new_plan)
.context(error::DatafusionSnafu {
msg: "Fail to optimize physical plan",
})?;
}
Ok(Arc::new(PhysicalPlanAdapter::new(plan.schema(), new_plan)))
}
}
@@ -224,7 +217,7 @@ impl QueryExecutor for DatafusionQueryEngine {
match plan.output_partitioning().partition_count() {
0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))),
1 => Ok(plan
.execute(0, ctx.state().runtime())
.execute(0, ctx.state().task_ctx())
.context(error::ExecutePhysicalPlanSnafu)?),
_ => {
// merge into a single partition
@@ -232,11 +225,11 @@ impl QueryExecutor for DatafusionQueryEngine {
CoalescePartitionsExec::new(Arc::new(DfPhysicalPlanAdapter(plan.clone())));
// CoalescePartitionsExec must produce a single partition
assert_eq!(1, plan.output_partitioning().partition_count());
let df_stream = plan.execute(0, ctx.state().runtime()).await.context(
error::DatafusionSnafu {
msg: "Failed to execute DataFusion merge exec",
},
)?;
let df_stream =
plan.execute(0, ctx.state().task_ctx())
.context(error::DatafusionSnafu {
msg: "Failed to execute DataFusion merge exec",
})?;
let stream = RecordBatchStreamAdapter::try_new(df_stream)
.context(error::ConvertDfRecordBatchStreamSnafu)?;
Ok(Box::pin(stream))
@@ -254,8 +247,7 @@ mod tests {
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use common_recordbatch::util;
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::arrow::array::UInt64Array;
use datatypes::vectors::{UInt64Vector, VectorRef};
use session::context::QueryContext;
use table::table::numbers::NumbersTable;
@@ -290,10 +282,10 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
r#"DfPlan(Limit: 20
Projection: #SUM(numbers.number)
Aggregate: groupBy=[[]], aggr=[[SUM(#numbers.number)]]
TableScan: numbers projection=None)"#
r#"DfPlan(Limit: skip=0, fetch=20
Projection: SUM(numbers.number)
Aggregate: groupBy=[[]], aggr=[[SUM(numbers.number)]]
TableScan: numbers)"#
);
}
@@ -311,20 +303,20 @@ mod tests {
Output::Stream(recordbatch) => {
let numbers = util::collect(recordbatch).await.unwrap();
assert_eq!(1, numbers.len());
assert_eq!(numbers[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, numbers[0].schema.arrow_schema().fields().len());
assert_eq!(numbers[0].num_columns(), 1);
assert_eq!(1, numbers[0].schema.num_columns());
assert_eq!(
"SUM(numbers.number)",
numbers[0].schema.arrow_schema().field(0).name()
numbers[0].schema.column_schemas()[0].name
);
let columns = numbers[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let batch = &numbers[0];
assert_eq!(1, batch.num_columns());
assert_eq!(batch.column(0).len(), 1);
assert_eq!(
*columns[0].as_any().downcast_ref::<UInt64Array>().unwrap(),
UInt64Array::from_slice(&[4950])
*batch.column(0),
Arc::new(UInt64Vector::from_slice(&[4950])) as VectorRef
);
}
_ => unreachable!(),

View File

@@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use common_query::logical_plan::create_aggregate_function;
use datafusion::catalog::TableReference;
use datafusion::datasource::TableProvider;
use datafusion::error::Result as DfResult;
use datafusion::physical_plan::udaf::AggregateUDF;
use datafusion::physical_plan::udf::ScalarUDF;
use datafusion::sql::planner::{ContextProvider, SqlToRel};
use datafusion_expr::TableSource;
use datatypes::arrow::datatypes::DataType;
use session::context::QueryContextRef;
use snafu::ResultExt;
@@ -50,7 +52,7 @@ impl<'a, S: ContextProvider + Send + Sync> DfPlanner<'a, S> {
let sql = query.inner.to_string();
let result = self
.sql_to_rel
.query_to_plan(query.inner)
.query_to_plan(query.inner, &mut HashMap::new())
.context(error::PlanSqlSnafu { sql })?;
Ok(LogicalPlan::DfPlan(result))
@@ -103,26 +105,14 @@ impl DfContextProviderAdapter {
}
}
/// TODO(dennis): Delegate all requests to ExecutionContext right now,
/// manage UDFs, UDAFs, variables by ourself in future.
impl ContextProvider for DfContextProviderAdapter {
fn get_table_provider(&self, name: TableReference) -> Option<Arc<dyn TableProvider>> {
fn get_table_provider(&self, name: TableReference) -> DfResult<Arc<dyn TableSource>> {
let schema = self.query_ctx.current_schema();
let execution_ctx = self.state.df_context().state.lock();
match name {
TableReference::Bare { table } if schema.is_some() => {
execution_ctx.get_table_provider(TableReference::Partial {
// unwrap safety: checked in this match's arm
schema: &schema.unwrap(),
table,
})
}
_ => execution_ctx.get_table_provider(name),
}
self.state.get_table_provider(schema.as_deref(), name)
}
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.df_context().state.lock().get_function_meta(name)
self.state.get_function_meta(name)
}
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
@@ -134,10 +124,6 @@ impl ContextProvider for DfContextProviderAdapter {
}
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
self.state
.df_context()
.state
.lock()
.get_variable_type(variable_names)
self.state.get_variable_type(variable_names)
}
}

View File

@@ -44,12 +44,10 @@ impl OptimizerRule for TypeConversionRule {
};
match plan {
LogicalPlan::Filter(Filter { predicate, input }) => {
Ok(LogicalPlan::Filter(Filter::try_new(
predicate.clone().rewrite(&mut converter)?,
Arc::new(self.optimize(input, optimizer_config)?),
)?))
}
LogicalPlan::Filter(filter) => Ok(LogicalPlan::Filter(Filter::try_new(
filter.predicate().clone().rewrite(&mut converter)?,
Arc::new(self.optimize(filter.input(), optimizer_config)?),
)?)),
LogicalPlan::TableScan(TableScan {
table_name,
source,
@@ -150,8 +148,8 @@ impl<'a> TypeConverter<'a> {
compute::cast(&value_arr, target_type).map_err(DataFusionError::ArrowError)?;
ScalarValue::try_from_array(
&Arc::from(arr), // index: Converts a value in `array` at `index` into a ScalarValue
0,
&arr,
0, // index: Converts a value in `array` at `index` into a ScalarValue
)
}
}

View File

@@ -21,14 +21,17 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_query::physical_plan::{SessionContext, TaskContext};
use common_query::prelude::ScalarUdf;
use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
use datafusion::optimizer::eliminate_limit::EliminateLimit;
use datafusion::optimizer::filter_push_down::FilterPushDown;
use datafusion::optimizer::limit_push_down::LimitPushDown;
use datafusion::optimizer::projection_push_down::ProjectionPushDown;
use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use datafusion::optimizer::to_approx_perc::ToApproxPerc;
use datafusion::execution::context::SessionConfig;
use datafusion::catalog::TableReference;
use datafusion::error::Result as DfResult;
use datafusion::execution::context::{SessionConfig, SessionState};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::udf::ScalarUDF;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_expr::{LogicalPlan as DfLogicalPlan, TableSource};
use datafusion_optimizer::optimizer::{Optimizer, OptimizerConfig};
use datafusion_sql::planner::ContextProvider;
use datatypes::arrow::datatypes::DataType;
use crate::datafusion::DfCatalogListAdapter;
use crate::optimizer::TypeConversionRule;
@@ -38,7 +41,7 @@ use crate::optimizer::TypeConversionRule;
// type in QueryEngine trait.
#[derive(Clone)]
pub struct QueryEngineState {
df_context: ExecutionContext,
df_context: SessionContext,
catalog_list: CatalogListRef,
aggregate_functions: Arc<RwLock<HashMap<String, AggregateFunctionMetaRef>>>,
}
@@ -52,25 +55,18 @@ impl fmt::Debug for QueryEngineState {
impl QueryEngineState {
pub(crate) fn new(catalog_list: CatalogListRef) -> Self {
let config = ExecutionConfig::new()
.with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)
.with_optimizer_rules(vec![
// TODO(hl): SimplifyExpressions is not exported.
Arc::new(TypeConversionRule {}),
// These are the default optimizer in datafusion
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateLimit::new()),
Arc::new(ProjectionPushDown::new()),
Arc::new(FilterPushDown::new()),
Arc::new(LimitPushDown::new()),
Arc::new(SingleDistinctToGroupBy::new()),
Arc::new(ToApproxPerc::new()),
]);
let runtime_env = Arc::new(RuntimeEnv::default());
let session_config = SessionConfig::new()
.with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME);
let mut optimizer = Optimizer::new(&OptimizerConfig::new());
// Apply the type conversion rule first.
optimizer.rules.insert(0, Arc::new(TypeConversionRule {}));
let df_context = ExecutionContext::with_config(config);
let mut session_state = SessionState::with_config_rt(session_config, runtime_env);
session_state.optimizer = optimizer;
session_state.catalog_list = Arc::new(DfCatalogListAdapter::new(catalog_list.clone()));
df_context.state.lock().catalog_list =
Arc::new(DfCatalogListAdapter::new(catalog_list.clone()));
let df_context = SessionContext::with_state(session_state);
Self {
df_context,
@@ -80,11 +76,15 @@ impl QueryEngineState {
}
/// Register a udf function
/// TODO(dennis): manage UDFs by ourself.
// TODO(dennis): manage UDFs by ourself.
pub fn register_udf(&self, udf: ScalarUdf) {
// `SessionContext` has a `register_udf()` method, which requires `&mut self`, this is
// a workaround.
// TODO(yingwen): Use `SessionContext::register_udf()` once it taks `&self`.
// It's implemented in https://github.com/apache/arrow-datafusion/pull/4612
self.df_context
.state
.lock()
.write()
.scalar_functions
.insert(udf.name.clone(), Arc::new(udf.into_df_udf()));
}
@@ -112,12 +112,59 @@ impl QueryEngineState {
}
#[inline]
pub(crate) fn df_context(&self) -> &ExecutionContext {
&self.df_context
pub(crate) fn task_ctx(&self) -> Arc<TaskContext> {
self.df_context.task_ctx()
}
#[inline]
pub(crate) fn runtime(&self) -> Arc<RuntimeEnv> {
self.df_context.runtime_env()
pub(crate) fn get_table_provider(
&self,
schema: Option<&str>,
name: TableReference,
) -> DfResult<Arc<dyn TableSource>> {
let state = self.df_context.state.read();
match name {
TableReference::Bare { table } if schema.is_some() => {
state.get_table_provider(TableReference::Partial {
// unwrap safety: checked in this match's arm
schema: schema.unwrap(),
table,
})
}
_ => state.get_table_provider(name),
}
}
pub(crate) fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
let state = self.df_context.state.read();
state.get_function_meta(name)
}
pub(crate) fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
let state = self.df_context.state.read();
state.get_variable_type(variable_names)
}
pub(crate) fn optimize(&self, plan: &DfLogicalPlan) -> DfResult<DfLogicalPlan> {
self.df_context.optimize(plan)
}
pub(crate) async fn create_physical_plan(
&self,
logical_plan: &DfLogicalPlan,
) -> DfResult<Arc<dyn ExecutionPlan>> {
self.df_context.create_physical_plan(logical_plan).await
}
pub(crate) fn optimize_physical_plan(
&self,
mut plan: Arc<dyn ExecutionPlan>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
let state = self.df_context.state.read();
let config = &state.config;
for optimizer in &state.physical_optimizers {
plan = optimizer.optimize(plan, config)?;
}
Ok(plan)
}
}

View File

@@ -261,10 +261,9 @@ mod test {
use common_query::Output;
use common_recordbatch::{RecordBatch, RecordBatches};
use common_time::timestamp::TimeUnit;
use datatypes::arrow::array::PrimitiveArray;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema, Schema, SchemaRef};
use datatypes::vectors::{StringVector, TimestampVector, UInt32Vector, VectorRef};
use datatypes::vectors::{StringVector, TimestampMillisecondVector, UInt32Vector, VectorRef};
use snafu::ResultExt;
use sql::statements::describe::DescribeTable;
use table::test_util::MemTable;
@@ -379,12 +378,12 @@ mod test {
.with_time_index(true),
];
let data = vec![
Arc::new(UInt32Vector::from_vec(vec![0])) as _,
Arc::new(TimestampVector::new(PrimitiveArray::from_vec(vec![0]))) as _,
Arc::new(UInt32Vector::from_slice(&[0])) as _,
Arc::new(TimestampMillisecondVector::from_slice(&[0])) as _,
];
let expected_columns = vec![
Arc::new(StringVector::from(vec!["t1", "t2"])) as _,
Arc::new(StringVector::from(vec!["UInt32", "Timestamp"])) as _,
Arc::new(StringVector::from(vec!["UInt32", "TimestampMillisecond"])) as _,
Arc::new(StringVector::from(vec![NULLABLE_YES, NULLABLE_NO])) as _,
Arc::new(StringVector::from(vec!["", "current_timestamp()"])) as _,
Arc::new(StringVector::from(vec![

View File

@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
mod function;
use std::sync::Arc;
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use datatypes::types::WrapperType;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
@@ -29,7 +29,7 @@ use session::context::QueryContext;
#[tokio::test]
async fn test_argmax_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();
macro_rules! test_argmax {
([], $( { $T:ty } ),*) => {
@@ -49,33 +49,23 @@ async fn test_argmax_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + PartialOrd,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + PartialOrd,
{
let result = execute_argmax(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("argmax", result[0].schema.arrow_schema().field(0).name());
let value = function::get_value_from_batches("argmax", result);
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = match numbers.len() {
0 => 0_u64,
_ => {
let mut index = 0;
let mut max = numbers[0].into();
let mut max = numbers[0];
for (i, &number) in numbers.iter().enumerate() {
if max < number.into() {
max = number.into();
if max < number {
max = number;
index = i;
}
}

View File

@@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
mod function;
use std::sync::Arc;
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use datatypes::types::WrapperType;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
@@ -30,7 +29,7 @@ use session::context::QueryContext;
#[tokio::test]
async fn test_argmin_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();
macro_rules! test_argmin {
([], $( { $T:ty } ),*) => {
@@ -50,33 +49,23 @@ async fn test_argmin_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + PartialOrd,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + PartialOrd,
{
let result = execute_argmin(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("argmin", result[0].schema.arrow_schema().field(0).name());
let value = function::get_value_from_batches("argmin", result);
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = match numbers.len() {
0 => 0_u32,
_ => {
let mut index = 0;
let mut min = numbers[0].into();
let mut min = numbers[0];
for (i, &number) in numbers.iter().enumerate() {
if min > number.into() {
min = number.into();
if min > number {
min = number;
index = i;
}
}

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME(yingwen): Consider move all tests under query/tests to query/src so we could reuse
// more codes.
use std::sync::Arc;
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
@@ -22,8 +24,8 @@ use common_recordbatch::{util, RecordBatch};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::types::PrimitiveElement;
use datatypes::vectors::PrimitiveVector;
use datatypes::types::WrapperType;
use datatypes::vectors::Helper;
use query::query_engine::QueryEngineFactory;
use query::QueryEngine;
use rand::Rng;
@@ -47,7 +49,7 @@ pub fn create_query_engine() -> Arc<dyn QueryEngine> {
column_schemas.push(column_schema);
let numbers = (1..=10).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec()));
columns.push(column);
)*
}
@@ -77,8 +79,7 @@ pub async fn get_numbers_from_table<'s, T>(
engine: Arc<dyn QueryEngine>,
) -> Vec<T>
where
T: PrimitiveElement,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType,
{
let sql = format!("SELECT {} FROM {}", column_name, table_name);
let plan = engine
@@ -92,8 +93,21 @@ where
};
let numbers = util::collect(recordbatch_stream).await.unwrap();
let columns = numbers[0].df_recordbatch.columns();
let column = VectorHelper::try_into_vector(&columns[0]).unwrap();
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&column) };
let column = numbers[0].column(0);
let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(column) };
column.iter_data().flatten().collect::<Vec<T>>()
}
pub fn get_value_from_batches(column_name: &str, batches: Vec<RecordBatch>) -> Value {
assert_eq!(1, batches.len());
assert_eq!(batches[0].num_columns(), 1);
assert_eq!(1, batches[0].schema.num_columns());
assert_eq!(column_name, batches[0].schema.column_schemas()[0].name);
let batch = &batches[0];
assert_eq!(1, batch.num_columns());
assert_eq!(batch.column(0).len(), 1);
let v = batch.column(0);
assert_eq!(1, v.len());
v.get(0)
}

View File

@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
mod function;
use std::sync::Arc;
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use datatypes::types::WrapperType;
use datatypes::value::OrderedFloat;
use format_num::NumberFormat;
use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
@@ -33,7 +32,7 @@ use session::context::QueryContext;
#[tokio::test]
async fn test_mean_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();
macro_rules! test_mean {
([], $( { $T:ty } ),*) => {
@@ -53,25 +52,15 @@ async fn test_mean_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + AsPrimitive<f64>,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + AsPrimitive<f64>,
{
let result = execute_mean(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("mean", result[0].schema.arrow_schema().field(0).name());
let value = function::get_value_from_batches("mean", result);
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
let expected_value = inc_stats::mean(expected_value.iter().cloned()).unwrap();

View File

@@ -26,12 +26,10 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use common_query::Output;
use common_recordbatch::{util, RecordBatch};
use datafusion::arrow_print;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::types::{PrimitiveElement, PrimitiveType};
use datatypes::vectors::PrimitiveVector;
use datatypes::types::{LogicalPrimitiveType, WrapperType};
use datatypes::vectors::Helper;
use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive;
use query::error::Result;
@@ -40,28 +38,30 @@ use session::context::QueryContext;
use table::test_util::MemTable;
#[derive(Debug, Default)]
struct MySumAccumulator<T, SumT>
where
T: Primitive + AsPrimitive<SumT>,
SumT: Primitive + std::ops::AddAssign,
{
struct MySumAccumulator<T, SumT> {
sum: SumT,
_phantom: PhantomData<T>,
}
impl<T, SumT> MySumAccumulator<T, SumT>
where
T: Primitive + AsPrimitive<SumT>,
SumT: Primitive + std::ops::AddAssign,
T: WrapperType,
SumT: WrapperType,
T::Native: AsPrimitive<SumT::Native>,
SumT::Native: std::ops::AddAssign,
{
#[inline(always)]
fn add(&mut self, v: T) {
self.sum += v.as_();
let mut sum_native = self.sum.into_native();
sum_native += v.into_native().as_();
self.sum = SumT::from_native(sum_native);
}
#[inline(always)]
fn merge(&mut self, s: SumT) {
self.sum += s;
let mut sum_native = self.sum.into_native();
sum_native += s.into_native();
self.sum = SumT::from_native(sum_native);
}
}
@@ -76,7 +76,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(MySumAccumulator::<$S, <$S as Primitive>::LargestType>::default()))
Ok(Box::new(MySumAccumulator::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default()))
},
{
let err_msg = format!(
@@ -95,7 +95,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().logical_type_id().data_type())
Ok(<<$S as LogicalPrimitiveType>::LargestType>::build_data_type())
},
{
unreachable!()
@@ -110,10 +110,10 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator {
impl<T, SumT> Accumulator for MySumAccumulator<T, SumT>
where
T: Primitive + AsPrimitive<SumT>,
for<'a> T: Scalar<RefType<'a> = T>,
SumT: Primitive + std::ops::AddAssign,
for<'a> SumT: Scalar<RefType<'a> = SumT>,
T: WrapperType,
SumT: WrapperType,
T::Native: AsPrimitive<SumT::Native>,
SumT::Native: std::ops::AddAssign,
{
fn state(&self) -> QueryResult<Vec<Value>> {
Ok(vec![self.sum.into()])
@@ -124,7 +124,7 @@ where
return Ok(());
};
let column = &values[0];
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(column) };
let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(column) };
for v in column.iter_data().flatten() {
self.add(v)
}
@@ -136,7 +136,7 @@ where
return Ok(());
};
let states = &states[0];
let states: &<SumT as Scalar>::VectorType = unsafe { VectorHelper::static_cast(states) };
let states: &<SumT as Scalar>::VectorType = unsafe { Helper::static_cast(states) };
for s in states.iter_data().flatten() {
self.merge(s)
}
@@ -154,65 +154,57 @@ async fn test_my_sum() -> Result<()> {
test_my_sum_with(
(1..=10).collect::<Vec<u32>>(),
vec![
"+--------+",
"| my_sum |",
"+--------+",
"| 55 |",
"+--------+",
],
r#"+--------+
| my_sum |
+--------+
| 55 |
+--------+"#,
)
.await?;
test_my_sum_with(
(-10..=11).collect::<Vec<i32>>(),
vec![
"+--------+",
"| my_sum |",
"+--------+",
"| 11 |",
"+--------+",
],
r#"+--------+
| my_sum |
+--------+
| 11 |
+--------+"#,
)
.await?;
test_my_sum_with(
vec![-1.0f32, 1.0, 2.0, 3.0, 4.0],
vec![
"+--------+",
"| my_sum |",
"+--------+",
"| 9 |",
"+--------+",
],
r#"+--------+
| my_sum |
+--------+
| 9 |
+--------+"#,
)
.await?;
test_my_sum_with(
vec![u32::MAX, u32::MAX],
vec![
"+------------+",
"| my_sum |",
"+------------+",
"| 8589934590 |",
"+------------+",
],
r#"+------------+
| my_sum |
+------------+
| 8589934590 |
+------------+"#,
)
.await?;
Ok(())
}
async fn test_my_sum_with<T>(numbers: Vec<T>, expected: Vec<&str>) -> Result<()>
async fn test_my_sum_with<T>(numbers: Vec<T>, expected: &str) -> Result<()>
where
T: PrimitiveElement,
T: WrapperType,
{
let table_name = format!("{}_numbers", std::any::type_name::<T>());
let column_name = format!("{}_number", std::any::type_name::<T>());
let column_schemas = vec![ColumnSchema::new(
column_name.clone(),
T::build_data_type(),
T::LogicalType::build_data_type(),
true,
)];
let schema = Arc::new(Schema::new(column_schemas.clone()));
let column: VectorRef = Arc::new(PrimitiveVector::<T>::from_vec(numbers));
let column: VectorRef = Arc::new(T::VectorType::from_vec(numbers));
let recordbatch = RecordBatch::new(schema, vec![column]).unwrap();
let testing_table = MemTable::new(&table_name, recordbatch);
@@ -236,14 +228,9 @@ where
Output::Stream(batch) => batch,
_ => unreachable!(),
};
let recordbatch = util::collect(recordbatch_stream).await.unwrap();
let df_recordbatch = recordbatch
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let batches = util::collect_batches(recordbatch_stream).await.unwrap();
let pretty_print = arrow_print::write(&df_recordbatch);
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
let pretty_print = batches.pretty_print().unwrap();
assert_eq!(expected, pretty_print);
Ok(())
}

View File

@@ -20,12 +20,10 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::types::PrimitiveElement;
use datatypes::vectors::PrimitiveVector;
use datatypes::vectors::Int32Vector;
use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
@@ -64,9 +62,8 @@ async fn test_percentile_correctness() -> Result<()> {
_ => unreachable!(),
};
let record_batch = util::collect(recordbatch_stream).await.unwrap();
let columns = record_batch[0].df_recordbatch.columns();
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
let value = v.get(0);
let column = record_batch[0].column(0);
let value = column.get(0);
assert_eq!(value, Value::from(9.280_000_000_000_001_f64));
Ok(())
}
@@ -77,26 +74,12 @@ async fn test_percentile_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + AsPrimitive<f64>,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + AsPrimitive<f64>,
{
let result = execute_percentile(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!(
"percentile",
result[0].schema.arrow_schema().field(0).name()
);
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let value = function::get_value_from_batches("percentile", result);
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
@@ -140,9 +123,9 @@ fn create_correctness_engine() -> Arc<dyn QueryEngine> {
let column_schema = ColumnSchema::new("corr_number", ConcreteDataType::int32_datatype(), true);
column_schemas.push(column_schema);
let numbers = vec![3_i32, 6_i32, 8_i32, 10_i32];
let numbers = [3_i32, 6_i32, 8_i32, 10_i32];
let column: VectorRef = Arc::new(PrimitiveVector::<i32>::from_vec(numbers.to_vec()));
let column: VectorRef = Arc::new(Int32Vector::from_slice(&numbers));
columns.push(column);
let schema = Arc::new(Schema::new(column_schemas));

View File

@@ -18,11 +18,9 @@ mod function;
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use datatypes::types::WrapperType;
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
@@ -31,13 +29,13 @@ use session::context::QueryContext;
#[tokio::test]
async fn test_polyval_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();
macro_rules! test_polyval {
([], $( { $T:ty } ),*) => {
$(
let column_name = format!("{}_number", std::any::type_name::<$T>());
test_polyval_success::<$T,<$T as Primitive>::LargestType>(&column_name, "numbers", engine.clone()).await?;
test_polyval_success::<$T, <<<$T as WrapperType>::LogicalType as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>(&column_name, "numbers", engine.clone()).await?;
)*
}
}
@@ -51,36 +49,27 @@ async fn test_polyval_success<T, PolyT>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: Primitive + AsPrimitive<PolyT> + PrimitiveElement,
PolyT: Primitive + std::ops::Mul<Output = PolyT> + std::iter::Sum,
for<'a> T: Scalar<RefType<'a> = T>,
for<'a> PolyT: Scalar<RefType<'a> = PolyT>,
i64: AsPrimitive<PolyT>,
T: WrapperType,
PolyT: WrapperType,
T::Native: AsPrimitive<PolyT::Native>,
PolyT::Native: std::ops::Mul<Output = PolyT::Native> + std::iter::Sum,
i64: AsPrimitive<PolyT::Native>,
{
let result = execute_polyval(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("polyval", result[0].schema.arrow_schema().field(0).name());
let value = function::get_value_from_batches("polyval", result);
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = numbers.iter().copied();
let x = 0i64;
let len = expected_value.len();
let expected_value: PolyT = expected_value
let expected_native: PolyT::Native = expected_value
.enumerate()
.map(|(i, value)| value.as_() * (x.pow((len - 1 - i) as u32)).as_())
.map(|(i, v)| v.into_native().as_() * (x.pow((len - 1 - i) as u32)).as_())
.sum();
assert_eq!(value, expected_value.into());
assert_eq!(value, PolyT::from_native(expected_native).into());
Ok(())
}

View File

@@ -32,7 +32,7 @@ pub fn pow(args: &[VectorRef]) -> Result<VectorRef> {
assert_eq!(exponent.len(), base.len());
let v = base
let iter = base
.iter_data()
.zip(exponent.iter_data())
.map(|(base, exponent)| {
@@ -42,8 +42,8 @@ pub fn pow(args: &[VectorRef]) -> Result<VectorRef> {
(Some(base), Some(exponent)) => Some(base.pow(exponent)),
_ => None,
}
})
.collect::<UInt32Vector>();
});
let v = UInt32Vector::from_owned_iterator(iter);
Ok(Arc::new(v) as _)
}

View File

@@ -13,30 +13,28 @@
// limitations under the License.
mod pow;
// This is used to suppress the warning: function `create_query_engine` is never used.
// FIXME(yingwen): We finally need to refactor these tests and move them to `query/src`
// so tests can share codes with other mods.
#[allow(unused)]
mod function;
use std::sync::Arc;
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider};
use catalog::{CatalogList, CatalogProvider, SchemaProvider};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::prelude::{create_udf, make_scalar_function, Volatility};
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datafusion::logical_plan::LogicalPlanBuilder;
use datatypes::arrow::array::UInt32Array;
use datatypes::for_all_primitive_types;
use datafusion::datasource::DefaultTableSource;
use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::types::{OrdPrimitive, PrimitiveElement};
use datatypes::vectors::{PrimitiveVector, UInt32Vector};
use num::NumCast;
use datatypes::vectors::UInt32Vector;
use query::error::Result;
use query::plan::LogicalPlan;
use query::query_engine::QueryEngineFactory;
use query::QueryEngine;
use rand::Rng;
use session::context::QueryContext;
use table::table::adapter::DfTableProviderAdapter;
use table::table::numbers::NumbersTable;
@@ -66,12 +64,16 @@ async fn test_datafusion_query_engine() -> Result<()> {
let limit = 10;
let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone()));
let plan = LogicalPlan::DfPlan(
LogicalPlanBuilder::scan("numbers", table_provider, None)
.unwrap()
.limit(limit)
.unwrap()
.build()
.unwrap(),
LogicalPlanBuilder::scan(
"numbers",
Arc::new(DefaultTableSource { table_provider }),
None,
)
.unwrap()
.limit(0, Some(limit))
.unwrap()
.build()
.unwrap(),
);
let output = engine.execute(&plan).await?;
@@ -84,17 +86,17 @@ async fn test_datafusion_query_engine() -> Result<()> {
let numbers = util::collect(recordbatch).await.unwrap();
assert_eq!(1, numbers.len());
assert_eq!(numbers[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, numbers[0].schema.arrow_schema().fields().len());
assert_eq!("number", numbers[0].schema.arrow_schema().field(0).name());
assert_eq!(numbers[0].num_columns(), 1);
assert_eq!(1, numbers[0].schema.num_columns());
assert_eq!("number", numbers[0].schema.column_schemas()[0].name);
let columns = numbers[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), limit);
let batch = &numbers[0];
assert_eq!(1, batch.num_columns());
assert_eq!(batch.column(0).len(), limit);
let expected: Vec<u32> = (0u32..limit as u32).collect();
assert_eq!(
*columns[0].as_any().downcast_ref::<UInt32Array>().unwrap(),
UInt32Array::from_slice(&expected)
*batch.column(0),
Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef
);
Ok(())
@@ -123,7 +125,8 @@ async fn test_udf() -> Result<()> {
let pow = make_scalar_function(pow);
let udf = create_udf(
"pow",
// datafusion already supports pow, so we use a different name.
"my_pow",
vec![
ConcreteDataType::uint32_datatype(),
ConcreteDataType::uint32_datatype(),
@@ -136,7 +139,7 @@ async fn test_udf() -> Result<()> {
engine.register_udf(udf);
let plan = engine.sql_to_plan(
"select pow(number, number) as p from numbers limit 10",
"select my_pow(number, number) as p from numbers limit 10",
Arc::new(QueryContext::new()),
)?;
@@ -148,202 +151,18 @@ async fn test_udf() -> Result<()> {
let numbers = util::collect(recordbatch).await.unwrap();
assert_eq!(1, numbers.len());
assert_eq!(numbers[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, numbers[0].schema.arrow_schema().fields().len());
assert_eq!("p", numbers[0].schema.arrow_schema().field(0).name());
assert_eq!(numbers[0].num_columns(), 1);
assert_eq!(1, numbers[0].schema.num_columns());
assert_eq!("p", numbers[0].schema.column_schemas()[0].name);
let columns = numbers[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 10);
let batch = &numbers[0];
assert_eq!(1, batch.num_columns());
assert_eq!(batch.column(0).len(), 10);
let expected: Vec<u32> = vec![1, 1, 4, 27, 256, 3125, 46656, 823543, 16777216, 387420489];
assert_eq!(
*columns[0].as_any().downcast_ref::<UInt32Array>().unwrap(),
UInt32Array::from_slice(&expected)
*batch.column(0),
Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef
);
Ok(())
}
fn create_query_engine() -> Arc<dyn QueryEngine> {
let schema_provider = Arc::new(MemorySchemaProvider::new());
let catalog_provider = Arc::new(MemoryCatalogProvider::new());
let catalog_list = Arc::new(MemoryCatalogManager::default());
// create table with primitives, and all columns' length are even
let mut column_schemas = vec![];
let mut columns = vec![];
macro_rules! create_even_number_table {
([], $( { $T:ty } ),*) => {
$(
let mut rng = rand::thread_rng();
let column_name = format!("{}_number_even", std::any::type_name::<$T>());
let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
column_schemas.push(column_schema);
let numbers = (1..=100).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
columns.push(column);
)*
}
}
for_all_primitive_types! { create_even_number_table }
let schema = Arc::new(Schema::new(column_schemas.clone()));
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let even_number_table = Arc::new(MemTable::new("even_numbers", recordbatch));
schema_provider
.register_table(
even_number_table.table_name().to_string(),
even_number_table,
)
.unwrap();
// create table with primitives, and all columns' length are odd
let mut column_schemas = vec![];
let mut columns = vec![];
macro_rules! create_odd_number_table {
([], $( { $T:ty } ),*) => {
$(
let mut rng = rand::thread_rng();
let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true);
column_schemas.push(column_schema);
let numbers = (1..=99).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>();
let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec()));
columns.push(column);
)*
}
}
for_all_primitive_types! { create_odd_number_table }
let schema = Arc::new(Schema::new(column_schemas.clone()));
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let odd_number_table = Arc::new(MemTable::new("odd_numbers", recordbatch));
schema_provider
.register_table(odd_number_table.table_name().to_string(), odd_number_table)
.unwrap();
catalog_provider
.register_schema(DEFAULT_SCHEMA_NAME.to_string(), schema_provider)
.unwrap();
catalog_list
.register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider)
.unwrap();
QueryEngineFactory::new(catalog_list).query_engine()
}
async fn get_numbers_from_table<'s, T>(
column_name: &'s str,
table_name: &'s str,
engine: Arc<dyn QueryEngine>,
) -> Vec<OrdPrimitive<T>>
where
T: PrimitiveElement,
for<'a> T: Scalar<RefType<'a> = T>,
{
let sql = format!("SELECT {} FROM {}", column_name, table_name);
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {
Output::Stream(batch) => batch,
_ => unreachable!(),
};
let numbers = util::collect(recordbatch_stream).await.unwrap();
let columns = numbers[0].df_recordbatch.columns();
let column = VectorHelper::try_into_vector(&columns[0]).unwrap();
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&column) };
column
.iter_data()
.flatten()
.map(|x| OrdPrimitive::<T>(x))
.collect::<Vec<OrdPrimitive<T>>>()
}
#[tokio::test]
async fn test_median_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
macro_rules! test_median {
([], $( { $T:ty } ),*) => {
$(
let column_name = format!("{}_number_even", std::any::type_name::<$T>());
test_median_success::<$T>(&column_name, "even_numbers", engine.clone()).await?;
let column_name = format!("{}_number_odd", std::any::type_name::<$T>());
test_median_success::<$T>(&column_name, "odd_numbers", engine.clone()).await?;
)*
}
}
for_all_primitive_types! { test_median }
Ok(())
}
async fn test_median_success<T>(
column_name: &str,
table_name: &str,
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement,
for<'a> T: Scalar<RefType<'a> = T>,
{
let result = execute_median(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!("median", result[0].schema.arrow_schema().field(0).name());
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let median = v.get(0);
let mut numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
numbers.sort();
let len = numbers.len();
let expected_median: Value = if len % 2 == 1 {
numbers[len / 2]
} else {
let a: f64 = NumCast::from(numbers[len / 2 - 1].as_primitive()).unwrap();
let b: f64 = NumCast::from(numbers[len / 2].as_primitive()).unwrap();
OrdPrimitive::<T>(NumCast::from(a / 2.0 + b / 2.0).unwrap())
}
.into();
assert_eq!(expected_median, median);
Ok(())
}
async fn execute_median<'a>(
column_name: &'a str,
table_name: &'a str,
engine: Arc<dyn QueryEngine>,
) -> RecordResult<Vec<RecordBatch>> {
let sql = format!(
"select MEDIAN({}) as median from {}",
column_name, table_name
);
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {
Output::Stream(batch) => batch,
_ => unreachable!(),
};
util::collect(recordbatch_stream).await
}

View File

@@ -18,11 +18,8 @@ mod function;
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use datatypes::types::WrapperType;
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
@@ -33,7 +30,7 @@ use statrs::statistics::Statistics;
#[tokio::test]
async fn test_scipy_stats_norm_cdf_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();
macro_rules! test_scipy_stats_norm_cdf {
([], $( { $T:ty } ),*) => {
@@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_cdf_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + AsPrimitive<f64>,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + AsPrimitive<f64>,
{
let result = execute_scipy_stats_norm_cdf(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!(
"scipy_stats_norm_cdf",
result[0].schema.arrow_schema().field(0).name()
);
let value = function::get_value_from_batches("scipy_stats_norm_cdf", result);
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
let mean = expected_value.clone().mean();
let stddev = expected_value.std_dev();

View File

@@ -18,11 +18,8 @@ mod function;
use common_query::Output;
use common_recordbatch::error::Result as RecordResult;
use common_recordbatch::{util, RecordBatch};
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::for_all_primitive_types;
use datatypes::prelude::*;
use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use datatypes::types::WrapperType;
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
@@ -33,7 +30,7 @@ use statrs::statistics::Statistics;
#[tokio::test]
async fn test_scipy_stats_norm_pdf_aggregator() -> Result<()> {
common_telemetry::init_default_ut_logging();
let engine = create_query_engine();
let engine = function::create_query_engine();
macro_rules! test_scipy_stats_norm_pdf {
([], $( { $T:ty } ),*) => {
@@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_pdf_success<T>(
engine: Arc<dyn QueryEngine>,
) -> Result<()>
where
T: PrimitiveElement + AsPrimitive<f64>,
for<'a> T: Scalar<RefType<'a> = T>,
T: WrapperType + AsPrimitive<f64>,
{
let result = execute_scipy_stats_norm_pdf(column_name, table_name, engine.clone())
.await
.unwrap();
assert_eq!(1, result.len());
assert_eq!(result[0].df_recordbatch.num_columns(), 1);
assert_eq!(1, result[0].schema.arrow_schema().fields().len());
assert_eq!(
"scipy_stats_norm_pdf",
result[0].schema.arrow_schema().field(0).name()
);
let value = function::get_value_from_batches("scipy_stats_norm_pdf", result);
let columns = result[0].df_recordbatch.columns();
assert_eq!(1, columns.len());
assert_eq!(columns[0].len(), 1);
let v = VectorHelper::try_into_vector(&columns[0]).unwrap();
assert_eq!(1, v.len());
let value = v.get(0);
let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let numbers =
function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await;
let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>();
let mean = expected_value.clone().mean();
let stddev = expected_value.std_dev();

View File

@@ -15,4 +15,4 @@ itertools = "0.10"
mito = { path = "../mito" }
once_cell = "1.10"
snafu = { version = "0.7", features = ["backtraces"] }
sqlparser = "0.15.0"
sqlparser = "0.26"

View File

@@ -14,5 +14,5 @@
pub use sqlparser::ast::{
ColumnDef, ColumnOption, ColumnOptionDef, DataType, Expr, Function, FunctionArg,
FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, Value,
FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, TimezoneInfo, Value,
};

View File

@@ -505,11 +505,7 @@ mod tests {
assert_matches!(
&stmts[0],
Statement::ShowTables(ShowTables {
kind: ShowKind::Where(sqlparser::ast::Expr::BinaryOp {
left: _,
right: _,
op: sqlparser::ast::BinaryOperator::Like,
}),
kind: ShowKind::Where(sqlparser::ast::Expr::Like { .. }),
database: None,
})
);
@@ -522,11 +518,7 @@ mod tests {
assert_matches!(
&stmts[0],
Statement::ShowTables(ShowTables {
kind: ShowKind::Where(sqlparser::ast::Expr::BinaryOp {
left: _,
right: _,
op: sqlparser::ast::BinaryOperator::Like,
}),
kind: ShowKind::Where(sqlparser::ast::Expr::Like { .. }),
database: Some(_),
})
);
@@ -543,11 +535,12 @@ mod tests {
distinct: false,
top: None,
projection: vec![sqlparser::ast::SelectItem::Wildcard],
into: None,
from: vec![sqlparser::ast::TableWithJoins {
relation: sqlparser::ast::TableFactor::Table {
name: sqlparser::ast::ObjectName(vec![sqlparser::ast::Ident::new("foo")]),
alias: None,
args: vec![],
args: None,
with_hints: vec![],
},
joins: vec![],
@@ -559,11 +552,12 @@ mod tests {
distribute_by: vec![],
sort_by: vec![],
having: None,
qualify: None,
};
let sp_statement = SpStatement::Query(Box::new(SpQuery {
with: None,
body: sqlparser::ast::SetExpr::Select(Box::new(select)),
body: Box::new(sqlparser::ast::SetExpr::Select(Box::new(select))),
order_by: vec![],
limit: None,
offset: None,
@@ -576,6 +570,7 @@ mod tests {
analyze: false,
verbose: false,
statement: Box::new(sp_statement),
format: None,
})
.unwrap();

View File

@@ -318,7 +318,7 @@ pub fn sql_data_type_to_concrete_data_type(data_type: &SqlDataType) -> Result<Co
}
.fail(),
},
SqlDataType::Timestamp => Ok(ConcreteDataType::timestamp_millisecond_datatype()),
SqlDataType::Timestamp(_) => Ok(ConcreteDataType::timestamp_millisecond_datatype()),
_ => error::SqlTypeNotSupportedSnafu {
t: data_type.clone(),
}
@@ -336,7 +336,7 @@ mod tests {
use datatypes::value::OrderedFloat;
use super::*;
use crate::ast::{DataType, Ident};
use crate::ast::{Ident, TimezoneInfo};
use crate::statements::ColumnOption;
fn check_type(sql_type: SqlDataType, data_type: ConcreteDataType) {
@@ -376,7 +376,7 @@ mod tests {
ConcreteDataType::datetime_datatype(),
);
check_type(
SqlDataType::Timestamp,
SqlDataType::Timestamp(TimezoneInfo::None),
ConcreteDataType::timestamp_millisecond_datatype(),
);
}
@@ -573,7 +573,7 @@ mod tests {
// test basic
let column_def = ColumnDef {
name: "col".into(),
data_type: DataType::Double,
data_type: SqlDataType::Double,
collation: None,
options: vec![],
};
@@ -588,7 +588,7 @@ mod tests {
// test not null
let column_def = ColumnDef {
name: "col".into(),
data_type: DataType::Double,
data_type: SqlDataType::Double,
collation: None,
options: vec![ColumnOptionDef {
name: None,

View File

@@ -49,7 +49,7 @@ impl Insert {
pub fn values(&self) -> Result<Vec<Vec<Value>>> {
let values = match &self.inner {
Statement::Insert { source, .. } => match &source.body {
Statement::Insert { source, .. } => match &*source.body {
SetExpr::Values(Values(exprs)) => sql_exprs_to_values(exprs)?,
_ => unreachable!(),
},