diff --git a/Cargo.lock b/Cargo.lock index b66d9254c6..070481be91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1899,7 +1899,7 @@ dependencies = [ "pin-project-lite", "rand 0.8.5", "smallvec", - "sqlparser 0.26.0", + "sqlparser", "tempfile", "tokio", "tokio-stream", @@ -1919,7 +1919,7 @@ dependencies = [ "object_store", "ordered-float 3.4.0", "parquet", - "sqlparser 0.26.0", + "sqlparser", ] [[package]] @@ -1932,7 +1932,7 @@ dependencies = [ "arrow", "datafusion-common", "log", - "sqlparser 0.26.0", + "sqlparser", ] [[package]] @@ -2003,7 +2003,7 @@ dependencies = [ "arrow", "datafusion-common", "datafusion-expr", - "sqlparser 0.26.0", + "sqlparser", ] [[package]] @@ -2467,7 +2467,6 @@ dependencies = [ "session", "snafu", "sql", - "sqlparser 0.15.0", "store-api", "substrait 0.1.0", "table", @@ -4950,7 +4949,9 @@ dependencies = [ "datafusion", "datafusion-common", "datafusion-expr", + "datafusion-optimizer", "datafusion-physical-expr", + "datafusion-sql", "datatypes", "format_num", "futures", @@ -6357,7 +6358,7 @@ dependencies = [ "mito", "once_cell", "snafu", - "sqlparser 0.15.0", + "sqlparser", ] [[package]] @@ -6386,15 +6387,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "sqlparser" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adbbea2526ad0d02ad9414a07c396078a5b944bbf9ca4fbab8f01bb4cb579081" -dependencies = [ - "log", -] - [[package]] name = "sqlparser" version = "0.26.0" diff --git a/src/common/function/src/scalars.rs b/src/common/function/src/scalars.rs index d362ea5f89..e9499b2151 100644 --- a/src/common/function/src/scalars.rs +++ b/src/common/function/src/scalars.rs @@ -23,6 +23,5 @@ pub(crate) mod test; mod timestamp; pub mod udf; -pub use aggregate::MedianAccumulatorCreator; pub use function::{Function, FunctionRef}; pub use function_registry::{FunctionRegistry, FUNCTION_REGISTRY}; diff --git a/src/common/function/src/scalars/aggregate.rs b/src/common/function/src/scalars/aggregate.rs index 8a4712a1b8..f605fff2f2 100644 --- a/src/common/function/src/scalars/aggregate.rs +++ b/src/common/function/src/scalars/aggregate.rs @@ -16,7 +16,6 @@ mod argmax; mod argmin; mod diff; mod mean; -mod median; mod percentile; mod polyval; mod scipy_stats_norm_cdf; @@ -29,7 +28,6 @@ pub use argmin::ArgminAccumulatorCreator; use common_query::logical_plan::AggregateFunctionCreatorRef; pub use diff::DiffAccumulatorCreator; pub use mean::MeanAccumulatorCreator; -pub use median::MedianAccumulatorCreator; pub use percentile::PercentileAccumulatorCreator; pub use polyval::PolyvalAccumulatorCreator; pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator; @@ -88,7 +86,6 @@ impl AggregateFunctions { }; } - register_aggr_func!("median", 1, MedianAccumulatorCreator); register_aggr_func!("diff", 1, DiffAccumulatorCreator); register_aggr_func!("mean", 1, MeanAccumulatorCreator); register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator); diff --git a/src/common/function/src/scalars/aggregate/median.rs b/src/common/function/src/scalars/aggregate/median.rs deleted file mode 100644 index facbd8702a..0000000000 --- a/src/common/function/src/scalars/aggregate/median.rs +++ /dev/null @@ -1,287 +0,0 @@ -// Copyright 2022 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::sync::Arc; - -use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore}; -use common_query::error::{ - CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result, -}; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; -use common_query::prelude::*; -use datatypes::prelude::*; -use datatypes::types::OrdPrimitive; -use datatypes::value::ListValue; -use datatypes::vectors::{ConstantVector, Helper, ListVector}; -use datatypes::with_match_primitive_type_id; -use num::NumCast; -use snafu::{ensure, OptionExt, ResultExt}; - -// This median calculation algorithm's details can be found at -// https://leetcode.cn/problems/find-median-from-data-stream/ -// -// Basically, it uses two heaps, a maximum heap and a minimum. The maximum heap stores numbers that -// are not greater than the median, and the minimum heap stores the greater. In a streaming of -// numbers, when a number is arrived, we adjust the heaps' tops, so that either one top is the -// median or both tops can be averaged to get the median. -// -// The time complexity to update the median is O(logn), O(1) to get the median; and the space -// complexity is O(n). (Ignore the costs for heap expansion.) -// -// From the point of algorithm, [quick select](https://en.wikipedia.org/wiki/Quickselect) might be -// better. But to use quick select here, we need a mutable self in the final calculation(`evaluate`) -// to swap stored numbers in the states vector. Though we can make our `evaluate` received -// `&mut self`, DataFusion calls our accumulator with `&self` (see `DfAccumulatorAdaptor`). That -// means we have to introduce some kinds of interior mutability, and the overhead is not neglectable. -// -// TODO(LFC): Use quick select to get median when we can modify DataFusion's code, and benchmark with two-heap algorithm. -#[derive(Debug, Default)] -pub struct Median -where - T: WrapperType, -{ - greater: BinaryHeap>>, - not_greater: BinaryHeap>, -} - -impl Median -where - T: WrapperType, -{ - fn push(&mut self, value: T) { - let value = OrdPrimitive::(value); - - if self.not_greater.is_empty() { - self.not_greater.push(value); - return; - } - // The `unwrap`s below are safe because there are `push`s before them. - if value <= *self.not_greater.peek().unwrap() { - self.not_greater.push(value); - if self.not_greater.len() > self.greater.len() + 1 { - self.greater.push(Reverse(self.not_greater.pop().unwrap())); - } - } else { - self.greater.push(Reverse(value)); - if self.greater.len() > self.not_greater.len() { - self.not_greater.push(self.greater.pop().unwrap().0); - } - } - } -} - -// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions -// to use them. -impl Accumulator for Median -where - T: WrapperType + NumCast, -{ - // This function serializes our state to `ScalarValue`, which DataFusion uses to pass this - // state between execution stages. Note that this can be arbitrary data. 
- // - // The `ScalarValue`s returned here will be passed in as argument `states: &[VectorRef]` to - // `merge_batch` function. - fn state(&self) -> Result> { - let nums = self - .greater - .iter() - .map(|x| &x.0) - .chain(self.not_greater.iter()) - .map(|&n| n.into()) - .collect::>(); - Ok(vec![Value::List(ListValue::new( - Some(Box::new(nums)), - T::LogicalType::build_data_type(), - ))]) - } - - // DataFusion calls this function to update the accumulator's state for a batch of inputs rows. - // It is expected this function to update the accumulator's state. - fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - ensure!(values.len() == 1, InvalidInputStateSnafu); - - // This is a unary accumulator, so only one column is provided. - let column = &values[0]; - let mut len = 1; - let column: &::VectorType = if column.is_const() { - len = column.len(); - let column: &ConstantVector = unsafe { Helper::static_cast(column) }; - unsafe { Helper::static_cast(column.inner()) } - } else { - unsafe { Helper::static_cast(column) } - }; - (0..len).for_each(|_| { - for v in column.iter_data().flatten() { - self.push(v); - } - }); - Ok(()) - } - - // DataFusion executes accumulators in partitions. In some execution stage, DataFusion will - // merge states from other accumulators (returned by `state()` method). - fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - // The states here are returned by the `state` method. Since we only returned a vector - // with one value in that method, `states[0]` is fine. - let states = &states[0]; - let states = states - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!( - "expect ListVector, got vector type {}", - states.vector_type_name() - ), - })?; - for state in states.values_iter() { - if let Some(state) = state.context(FromScalarValueSnafu)? { - // merging state is simply accumulate stored numbers from others', so just call update - self.update_batch(&[state])? - } - } - Ok(()) - } - - // DataFusion expects this function to return the final value of this aggregator. 
- fn evaluate(&self) -> Result { - if self.not_greater.is_empty() { - assert!( - self.greater.is_empty(), - "not expected in two-heap median algorithm, there must be a bug when implementing it" - ); - return Ok(Value::Null); - } - - // unwrap is safe because we checked not_greater heap's len above - let not_greater = *self.not_greater.peek().unwrap(); - let median = if self.not_greater.len() > self.greater.len() { - not_greater.into() - } else { - // unwrap is safe because greater heap len >= not_greater heap len, which is > 0 - let greater = self.greater.peek().unwrap(); - - // the following three NumCast's `unwrap`s are safe because T is primitive - let not_greater_v: f64 = NumCast::from(not_greater.as_primitive()).unwrap(); - let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap(); - let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap(); - median.into() - }; - Ok(median) - } -} - -#[as_aggr_func_creator] -#[derive(Debug, Default, AggrFuncTypeStore)] -pub struct MedianAccumulatorCreator {} - -impl AggregateFunctionCreator for MedianAccumulatorCreator { - fn creator(&self) -> AccumulatorCreatorFunction { - let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| { - let input_type = &types[0]; - with_match_primitive_type_id!( - input_type.logical_type_id(), - |$S| { - Ok(Box::new(Median::<<$S as LogicalPrimitiveType>::Wrapper>::default())) - }, - { - let err_msg = format!( - "\"MEDIAN\" aggregate function not support data type {:?}", - input_type.logical_type_id(), - ); - CreateAccumulatorSnafu { err_msg }.fail()? - } - ) - }); - creator - } - - fn output_type(&self) -> Result { - let input_types = self.input_types()?; - ensure!(input_types.len() == 1, InvalidInputStateSnafu); - // unwrap is safe because we have checked input_types len must equals 1 - Ok(input_types.into_iter().next().unwrap()) - } - - fn state_types(&self) -> Result> { - Ok(vec![ConcreteDataType::list_datatype(self.output_type()?)]) - } -} - -#[cfg(test)] -mod test { - use datatypes::vectors::Int32Vector; - - use super::*; - #[test] - fn test_update_batch() { - // test update empty batch, expect not updating anything - let mut median = Median::::default(); - assert!(median.update_batch(&[]).is_ok()); - assert!(median.not_greater.is_empty()); - assert!(median.greater.is_empty()); - assert_eq!(Value::Null, median.evaluate().unwrap()); - - // test update one not-null value - let mut median = Median::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Some(42)]))]; - assert!(median.update_batch(&v).is_ok()); - assert_eq!(Value::Int32(42), median.evaluate().unwrap()); - - // test update one null value - let mut median = Median::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![Option::::None]))]; - assert!(median.update_batch(&v).is_ok()); - assert_eq!(Value::Null, median.evaluate().unwrap()); - - // test update no null-value batch - let mut median = Median::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-1i32), - Some(1), - Some(2), - ]))]; - assert!(median.update_batch(&v).is_ok()); - assert_eq!(Value::Int32(1), median.evaluate().unwrap()); - - // test update null-value batch - let mut median = Median::::default(); - let v: Vec = vec![Arc::new(Int32Vector::from(vec![ - Some(-2i32), - None, - Some(3), - Some(4), - ]))]; - assert!(median.update_batch(&v).is_ok()); - assert_eq!(Value::Int32(3), median.evaluate().unwrap()); - - // test update with constant vector - let mut median = Median::::default(); - let v: 
Vec = vec![Arc::new(ConstantVector::new( - Arc::new(Int32Vector::from_vec(vec![4])), - 10, - ))]; - assert!(median.update_batch(&v).is_ok()); - assert_eq!(Value::Int32(4), median.evaluate().unwrap()); - } -} diff --git a/src/common/recordbatch/src/lib.rs b/src/common/recordbatch/src/lib.rs index 23aa04a9bf..77987eac25 100644 --- a/src/common/recordbatch/src/lib.rs +++ b/src/common/recordbatch/src/lib.rs @@ -66,7 +66,7 @@ impl Stream for EmptyRecordBatchStream { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct RecordBatches { schema: SchemaRef, batches: Vec, diff --git a/src/common/recordbatch/src/util.rs b/src/common/recordbatch/src/util.rs index 1cca3ee988..4b2f1a67c8 100644 --- a/src/common/recordbatch/src/util.rs +++ b/src/common/recordbatch/src/util.rs @@ -15,13 +15,20 @@ use futures::TryStreamExt; use crate::error::Result; -use crate::{RecordBatch, SendableRecordBatchStream}; +use crate::{RecordBatch, RecordBatches, SendableRecordBatchStream}; /// Collect all the items from the stream into a vector of [`RecordBatch`]. pub async fn collect(stream: SendableRecordBatchStream) -> Result> { stream.try_collect::>().await } +/// Collect all the items from the stream into [RecordBatches]. +pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result { + let schema = stream.schema(); + let batches = stream.try_collect::>().await?; + RecordBatches::try_new(schema, batches) +} + #[cfg(test)] mod tests { use std::mem; @@ -90,7 +97,14 @@ mod tests { }; let batches = collect(Box::pin(stream)).await.unwrap(); assert_eq!(1, batches.len()); - assert_eq!(batch, batches[0]); + + let stream = MockRecordBatchStream { + schema: schema.clone(), + batch: Some(batch.clone()), + }; + let batches = collect_batches(Box::pin(stream)).await.unwrap(); + let expect_batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap(); + assert_eq!(expect_batches, batches); } } diff --git a/src/datatypes/src/types/timestamp_type.rs b/src/datatypes/src/types/timestamp_type.rs index ba352bac1e..ddd740b203 100644 --- a/src/datatypes/src/types/timestamp_type.rs +++ b/src/datatypes/src/types/timestamp_type.rs @@ -70,7 +70,7 @@ macro_rules! 
impl_data_type_for_timestamp { impl DataType for [] { fn name(&self) -> &str { - stringify!([]) + stringify!([]) } fn logical_type_id(&self) -> LogicalTypeId { diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index d7cf5325bd..457c774606 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -1346,7 +1346,7 @@ mod tests { ConcreteDataType::timestamp_second_datatype(), )) .to_string(), - "TimestampSecondType[]" + "TimestampSecond[]" ); assert_eq!( Value::List(ListValue::new( @@ -1354,7 +1354,7 @@ mod tests { ConcreteDataType::timestamp_millisecond_datatype(), )) .to_string(), - "TimestampMillisecondType[]" + "TimestampMillisecond[]" ); assert_eq!( Value::List(ListValue::new( @@ -1362,7 +1362,7 @@ mod tests { ConcreteDataType::timestamp_microsecond_datatype(), )) .to_string(), - "TimestampMicrosecondType[]" + "TimestampMicrosecond[]" ); assert_eq!( Value::List(ListValue::new( @@ -1370,7 +1370,7 @@ mod tests { ConcreteDataType::timestamp_nanosecond_datatype(), )) .to_string(), - "TimestampNanosecondType[]" + "TimestampNanosecond[]" ); } diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index cd6f298bd0..304dcaa033 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -40,7 +40,6 @@ servers = { path = "../servers" } session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } sql = { path = "../sql" } -sqlparser = "0.15" store-api = { path = "../store-api" } substrait = { path = "../common/substrait" } table = { path = "../table" } diff --git a/src/frontend/src/expr_factory.rs b/src/frontend/src/expr_factory.rs index 9f406ace0b..204eb42d92 100644 --- a/src/frontend/src/expr_factory.rs +++ b/src/frontend/src/expr_factory.rs @@ -19,9 +19,9 @@ use api::helper::ColumnDataTypeWrapper; use api::v1::{Column, ColumnDataType, CreateExpr}; use datatypes::schema::ColumnSchema; use snafu::{ensure, ResultExt}; +use sql::ast::{ColumnDef, TableConstraint}; use sql::statements::create::{CreateTable, TIME_INDEX}; use sql::statements::{column_def_to_schema, table_idents_to_full_name}; -use sqlparser::ast::{ColumnDef, TableConstraint}; use crate::error::{ BuildCreateExprOnInsertionSnafu, ColumnDataTypeSnafu, ConvertColumnDefaultConstraintSnafu, diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index d32e12ee24..0adeb96f31 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -35,10 +35,10 @@ use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; +use sql::ast::Value as SqlValue; use sql::statements::create::Partitions; use sql::statements::sql_value_to_value; use sql::statements::statement::Statement; -use sqlparser::ast::Value as SqlValue; use table::metadata::{RawTableInfo, RawTableMeta, TableIdent, TableType}; use crate::catalog::FrontendCatalogManager; @@ -454,9 +454,9 @@ fn find_partition_columns( #[cfg(test)] mod test { + use sql::dialect::GenericDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; - use sqlparser::dialect::GenericDialect; use super::*; use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory}; diff --git a/src/frontend/src/table.rs b/src/frontend/src/table.rs index 36d229a245..d7e02931d8 100644 --- a/src/frontend/src/table.rs +++ b/src/frontend/src/table.rs @@ -530,9 +530,9 @@ mod test { use 
meta_srv::mocks::MockInfo; use meta_srv::service::store::kv::KvStoreRef; use meta_srv::service::store::memory::MemStore; + use sql::dialect::GenericDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; - use sqlparser::dialect::GenericDialect; use table::metadata::{TableInfoBuilder, TableMetaBuilder}; use table::TableRef; use tempdir::TempDir; diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 6d92c2e9d8..1bb9da358a 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -18,7 +18,9 @@ common-time = { path = "../common/time" } datafusion = "14.0.0" datafusion-common = "14.0.0" datafusion-expr = "14.0.0" +datafusion-optimizer = "14.0.0" datafusion-physical-expr = "14.0.0" +datafusion-sql = "14.0.0" datatypes = { path = "../datatypes" } futures = "0.3" futures-util = "0.3" diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 8dda26a5db..0968d99357 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -141,7 +141,6 @@ impl LogicalOptimizer for DatafusionQueryEngine { LogicalPlan::DfPlan(df_plan) => { let optimized_plan = self.state - .df_context() .optimize(df_plan) .context(error::DatafusionSnafu { msg: "Fail to optimize logical plan", @@ -163,14 +162,11 @@ impl PhysicalPlanner for DatafusionQueryEngine { let _timer = timer!(metric::METRIC_CREATE_PHYSICAL_ELAPSED); match logical_plan { LogicalPlan::DfPlan(df_plan) => { - let physical_plan = self - .state - .df_context() - .create_physical_plan(df_plan) - .await - .context(error::DatafusionSnafu { + let physical_plan = self.state.create_physical_plan(df_plan).await.context( + error::DatafusionSnafu { msg: "Fail to create physical plan", - })?; + }, + )?; Ok(Arc::new(PhysicalPlanAdapter::new( Arc::new( @@ -193,22 +189,19 @@ impl PhysicalOptimizer for DatafusionQueryEngine { plan: Arc, ) -> Result> { let _timer = timer!(metric::METRIC_OPTIMIZE_PHYSICAL_ELAPSED); - let config = &self.state.df_context().state.lock().config; - let optimizers = &config.physical_optimizers; - let mut new_plan = plan + let new_plan = plan .as_any() .downcast_ref::() .context(error::PhysicalPlanDowncastSnafu)? 
.df_plan(); - for optimizer in optimizers { - new_plan = optimizer - .optimize(new_plan, config) + let new_plan = + self.state + .optimize_physical_plan(new_plan) .context(error::DatafusionSnafu { msg: "Fail to optimize physical plan", })?; - } Ok(Arc::new(PhysicalPlanAdapter::new(plan.schema(), new_plan))) } } @@ -224,7 +217,7 @@ impl QueryExecutor for DatafusionQueryEngine { match plan.output_partitioning().partition_count() { 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), 1 => Ok(plan - .execute(0, ctx.state().runtime()) + .execute(0, ctx.state().task_ctx()) .context(error::ExecutePhysicalPlanSnafu)?), _ => { // merge into a single partition @@ -232,11 +225,11 @@ impl QueryExecutor for DatafusionQueryEngine { CoalescePartitionsExec::new(Arc::new(DfPhysicalPlanAdapter(plan.clone()))); // CoalescePartitionsExec must produce a single partition assert_eq!(1, plan.output_partitioning().partition_count()); - let df_stream = plan.execute(0, ctx.state().runtime()).await.context( - error::DatafusionSnafu { - msg: "Failed to execute DataFusion merge exec", - }, - )?; + let df_stream = + plan.execute(0, ctx.state().task_ctx()) + .context(error::DatafusionSnafu { + msg: "Failed to execute DataFusion merge exec", + })?; let stream = RecordBatchStreamAdapter::try_new(df_stream) .context(error::ConvertDfRecordBatchStreamSnafu)?; Ok(Box::pin(stream)) @@ -254,8 +247,7 @@ mod tests { use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use common_recordbatch::util; - use datafusion::field_util::{FieldExt, SchemaExt}; - use datatypes::arrow::array::UInt64Array; + use datatypes::vectors::{UInt64Vector, VectorRef}; use session::context::QueryContext; use table::table::numbers::NumbersTable; @@ -290,10 +282,10 @@ mod tests { assert_eq!( format!("{:?}", plan), - r#"DfPlan(Limit: 20 - Projection: #SUM(numbers.number) - Aggregate: groupBy=[[]], aggr=[[SUM(#numbers.number)]] - TableScan: numbers projection=None)"# + r#"DfPlan(Limit: skip=0, fetch=20 + Projection: SUM(numbers.number) + Aggregate: groupBy=[[]], aggr=[[SUM(numbers.number)]] + TableScan: numbers)"# ); } @@ -311,20 +303,20 @@ mod tests { Output::Stream(recordbatch) => { let numbers = util::collect(recordbatch).await.unwrap(); assert_eq!(1, numbers.len()); - assert_eq!(numbers[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, numbers[0].schema.arrow_schema().fields().len()); + assert_eq!(numbers[0].num_columns(), 1); + assert_eq!(1, numbers[0].schema.num_columns()); assert_eq!( "SUM(numbers.number)", - numbers[0].schema.arrow_schema().field(0).name() + numbers[0].schema.column_schemas()[0].name ); - let columns = numbers[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); + let batch = &numbers[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(batch.column(0).len(), 1); assert_eq!( - *columns[0].as_any().downcast_ref::().unwrap(), - UInt64Array::from_slice(&[4950]) + *batch.column(0), + Arc::new(UInt64Vector::from_slice(&[4950])) as VectorRef ); } _ => unreachable!(), diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 6d70109e74..4c87654e3c 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -12,14 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::HashMap; use std::sync::Arc; use common_query::logical_plan::create_aggregate_function; use datafusion::catalog::TableReference; -use datafusion::datasource::TableProvider; +use datafusion::error::Result as DfResult; use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::udf::ScalarUDF; use datafusion::sql::planner::{ContextProvider, SqlToRel}; +use datafusion_expr::TableSource; use datatypes::arrow::datatypes::DataType; use session::context::QueryContextRef; use snafu::ResultExt; @@ -50,7 +52,7 @@ impl<'a, S: ContextProvider + Send + Sync> DfPlanner<'a, S> { let sql = query.inner.to_string(); let result = self .sql_to_rel - .query_to_plan(query.inner) + .query_to_plan(query.inner, &mut HashMap::new()) .context(error::PlanSqlSnafu { sql })?; Ok(LogicalPlan::DfPlan(result)) @@ -103,26 +105,14 @@ impl DfContextProviderAdapter { } } -/// TODO(dennis): Delegate all requests to ExecutionContext right now, -/// manage UDFs, UDAFs, variables by ourself in future. impl ContextProvider for DfContextProviderAdapter { - fn get_table_provider(&self, name: TableReference) -> Option> { + fn get_table_provider(&self, name: TableReference) -> DfResult> { let schema = self.query_ctx.current_schema(); - let execution_ctx = self.state.df_context().state.lock(); - match name { - TableReference::Bare { table } if schema.is_some() => { - execution_ctx.get_table_provider(TableReference::Partial { - // unwrap safety: checked in this match's arm - schema: &schema.unwrap(), - table, - }) - } - _ => execution_ctx.get_table_provider(name), - } + self.state.get_table_provider(schema.as_deref(), name) } fn get_function_meta(&self, name: &str) -> Option> { - self.state.df_context().state.lock().get_function_meta(name) + self.state.get_function_meta(name) } fn get_aggregate_meta(&self, name: &str) -> Option> { @@ -134,10 +124,6 @@ impl ContextProvider for DfContextProviderAdapter { } fn get_variable_type(&self, variable_names: &[String]) -> Option { - self.state - .df_context() - .state - .lock() - .get_variable_type(variable_names) + self.state.get_variable_type(variable_names) } } diff --git a/src/query/src/optimizer.rs b/src/query/src/optimizer.rs index 6eadf5b7e7..2e66588769 100644 --- a/src/query/src/optimizer.rs +++ b/src/query/src/optimizer.rs @@ -44,12 +44,10 @@ impl OptimizerRule for TypeConversionRule { }; match plan { - LogicalPlan::Filter(Filter { predicate, input }) => { - Ok(LogicalPlan::Filter(Filter::try_new( - predicate.clone().rewrite(&mut converter)?, - Arc::new(self.optimize(input, optimizer_config)?), - )?)) - } + LogicalPlan::Filter(filter) => Ok(LogicalPlan::Filter(Filter::try_new( + filter.predicate().clone().rewrite(&mut converter)?, + Arc::new(self.optimize(filter.input(), optimizer_config)?), + )?)), LogicalPlan::TableScan(TableScan { table_name, source, @@ -150,8 +148,8 @@ impl<'a> TypeConverter<'a> { compute::cast(&value_arr, target_type).map_err(DataFusionError::ArrowError)?; ScalarValue::try_from_array( - &Arc::from(arr), // index: Converts a value in `array` at `index` into a ScalarValue - 0, + &arr, + 0, // index: Converts a value in `array` at `index` into a ScalarValue ) } } diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index b638ead920..a72b0203e3 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -21,14 +21,17 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_function::scalars::aggregate::AggregateFunctionMetaRef; 
use common_query::physical_plan::{SessionContext, TaskContext}; use common_query::prelude::ScalarUdf; -use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate; -use datafusion::optimizer::eliminate_limit::EliminateLimit; -use datafusion::optimizer::filter_push_down::FilterPushDown; -use datafusion::optimizer::limit_push_down::LimitPushDown; -use datafusion::optimizer::projection_push_down::ProjectionPushDown; -use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; -use datafusion::optimizer::to_approx_perc::ToApproxPerc; -use datafusion::execution::context::SessionConfig; +use datafusion::catalog::TableReference; +use datafusion::error::Result as DfResult; +use datafusion::execution::context::{SessionConfig, SessionState}; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::physical_plan::udf::ScalarUDF; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_expr::{LogicalPlan as DfLogicalPlan, TableSource}; +use datafusion_optimizer::optimizer::{Optimizer, OptimizerConfig}; +use datafusion_sql::planner::ContextProvider; +use datatypes::arrow::datatypes::DataType; + use crate::datafusion::DfCatalogListAdapter; use crate::optimizer::TypeConversionRule; @@ -38,7 +41,7 @@ use crate::optimizer::TypeConversionRule; // type in QueryEngine trait. #[derive(Clone)] pub struct QueryEngineState { - df_context: ExecutionContext, + df_context: SessionContext, catalog_list: CatalogListRef, aggregate_functions: Arc>>, } @@ -52,25 +55,18 @@ impl fmt::Debug for QueryEngineState { impl QueryEngineState { pub(crate) fn new(catalog_list: CatalogListRef) -> Self { - let config = ExecutionConfig::new() - .with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME) - .with_optimizer_rules(vec![ - // TODO(hl): SimplifyExpressions is not exported. - Arc::new(TypeConversionRule {}), - // These are the default optimizer in datafusion - Arc::new(CommonSubexprEliminate::new()), - Arc::new(EliminateLimit::new()), - Arc::new(ProjectionPushDown::new()), - Arc::new(FilterPushDown::new()), - Arc::new(LimitPushDown::new()), - Arc::new(SingleDistinctToGroupBy::new()), - Arc::new(ToApproxPerc::new()), - ]); + let runtime_env = Arc::new(RuntimeEnv::default()); + let session_config = SessionConfig::new() + .with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME); + let mut optimizer = Optimizer::new(&OptimizerConfig::new()); + // Apply the type conversion rule first. + optimizer.rules.insert(0, Arc::new(TypeConversionRule {})); - let df_context = ExecutionContext::with_config(config); + let mut session_state = SessionState::with_config_rt(session_config, runtime_env); + session_state.optimizer = optimizer; + session_state.catalog_list = Arc::new(DfCatalogListAdapter::new(catalog_list.clone())); - df_context.state.lock().catalog_list = - Arc::new(DfCatalogListAdapter::new(catalog_list.clone())); + let df_context = SessionContext::with_state(session_state); Self { df_context, @@ -80,11 +76,15 @@ impl QueryEngineState { } /// Register a udf function - /// TODO(dennis): manage UDFs by ourself. + // TODO(dennis): manage UDFs by ourself. pub fn register_udf(&self, udf: ScalarUdf) { + // `SessionContext` has a `register_udf()` method, which requires `&mut self`, this is + // a workaround. + // TODO(yingwen): Use `SessionContext::register_udf()` once it taks `&self`. 
+ // It's implemented in https://github.com/apache/arrow-datafusion/pull/4612 self.df_context .state - .lock() + .write() .scalar_functions .insert(udf.name.clone(), Arc::new(udf.into_df_udf())); } @@ -112,12 +112,59 @@ impl QueryEngineState { } #[inline] - pub(crate) fn df_context(&self) -> &ExecutionContext { - &self.df_context + pub(crate) fn task_ctx(&self) -> Arc { + self.df_context.task_ctx() } - #[inline] - pub(crate) fn runtime(&self) -> Arc { - self.df_context.runtime_env() + pub(crate) fn get_table_provider( + &self, + schema: Option<&str>, + name: TableReference, + ) -> DfResult> { + let state = self.df_context.state.read(); + match name { + TableReference::Bare { table } if schema.is_some() => { + state.get_table_provider(TableReference::Partial { + // unwrap safety: checked in this match's arm + schema: schema.unwrap(), + table, + }) + } + _ => state.get_table_provider(name), + } + } + + pub(crate) fn get_function_meta(&self, name: &str) -> Option> { + let state = self.df_context.state.read(); + state.get_function_meta(name) + } + + pub(crate) fn get_variable_type(&self, variable_names: &[String]) -> Option { + let state = self.df_context.state.read(); + state.get_variable_type(variable_names) + } + + pub(crate) fn optimize(&self, plan: &DfLogicalPlan) -> DfResult { + self.df_context.optimize(plan) + } + + pub(crate) async fn create_physical_plan( + &self, + logical_plan: &DfLogicalPlan, + ) -> DfResult> { + self.df_context.create_physical_plan(logical_plan).await + } + + pub(crate) fn optimize_physical_plan( + &self, + mut plan: Arc, + ) -> DfResult> { + let state = self.df_context.state.read(); + let config = &state.config; + for optimizer in &state.physical_optimizers { + plan = optimizer.optimize(plan, config)?; + } + + Ok(plan) } } diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index 2854fed7fc..327394416e 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -261,10 +261,9 @@ mod test { use common_query::Output; use common_recordbatch::{RecordBatch, RecordBatches}; use common_time::timestamp::TimeUnit; - use datatypes::arrow::array::PrimitiveArray; use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema, Schema, SchemaRef}; - use datatypes::vectors::{StringVector, TimestampVector, UInt32Vector, VectorRef}; + use datatypes::vectors::{StringVector, TimestampMillisecondVector, UInt32Vector, VectorRef}; use snafu::ResultExt; use sql::statements::describe::DescribeTable; use table::test_util::MemTable; @@ -379,12 +378,12 @@ mod test { .with_time_index(true), ]; let data = vec![ - Arc::new(UInt32Vector::from_vec(vec![0])) as _, - Arc::new(TimestampVector::new(PrimitiveArray::from_vec(vec![0]))) as _, + Arc::new(UInt32Vector::from_slice(&[0])) as _, + Arc::new(TimestampMillisecondVector::from_slice(&[0])) as _, ]; let expected_columns = vec![ Arc::new(StringVector::from(vec!["t1", "t2"])) as _, - Arc::new(StringVector::from(vec!["UInt32", "Timestamp"])) as _, + Arc::new(StringVector::from(vec!["UInt32", "TimestampMillisecond"])) as _, Arc::new(StringVector::from(vec![NULLABLE_YES, NULLABLE_NO])) as _, Arc::new(StringVector::from(vec!["", "current_timestamp()"])) as _, Arc::new(StringVector::from(vec![ diff --git a/src/query/tests/argmax_test.rs b/src/query/tests/argmax_test.rs index 11f0167a09..cbf1ae931d 100644 --- a/src/query/tests/argmax_test.rs +++ b/src/query/tests/argmax_test.rs @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::sync::Arc; mod function; + +use std::sync::Arc; + use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use query::error::Result; use query::QueryEngine; use session::context::QueryContext; @@ -29,7 +29,7 @@ use session::context::QueryContext; #[tokio::test] async fn test_argmax_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_argmax { ([], $( { $T:ty } ),*) => { @@ -49,33 +49,23 @@ async fn test_argmax_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + PartialOrd, - for<'a> T: Scalar = T>, + T: WrapperType + PartialOrd, { let result = execute_argmax(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("argmax", result[0].schema.arrow_schema().field(0).name()); + let value = function::get_value_from_batches("argmax", result); - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); - - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = match numbers.len() { 0 => 0_u64, _ => { let mut index = 0; - let mut max = numbers[0].into(); + let mut max = numbers[0]; for (i, &number) in numbers.iter().enumerate() { - if max < number.into() { - max = number.into(); + if max < number { + max = number; index = i; } } diff --git a/src/query/tests/argmin_test.rs b/src/query/tests/argmin_test.rs index 2a509f05fd..546fa9ae23 100644 --- a/src/query/tests/argmin_test.rs +++ b/src/query/tests/argmin_test.rs @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; mod function; +use std::sync::Arc; + use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use query::error::Result; use query::QueryEngine; use session::context::QueryContext; @@ -30,7 +29,7 @@ use session::context::QueryContext; #[tokio::test] async fn test_argmin_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! 
test_argmin { ([], $( { $T:ty } ),*) => { @@ -50,33 +49,23 @@ async fn test_argmin_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + PartialOrd, - for<'a> T: Scalar = T>, + T: WrapperType + PartialOrd, { let result = execute_argmin(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("argmin", result[0].schema.arrow_schema().field(0).name()); + let value = function::get_value_from_batches("argmin", result); - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); - - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = match numbers.len() { 0 => 0_u32, _ => { let mut index = 0; - let mut min = numbers[0].into(); + let mut min = numbers[0]; for (i, &number) in numbers.iter().enumerate() { - if min > number.into() { - min = number.into(); + if min > number { + min = number; index = i; } } diff --git a/src/query/tests/function.rs b/src/query/tests/function.rs index 040dfa7a6b..7de93a6265 100644 --- a/src/query/tests/function.rs +++ b/src/query/tests/function.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// FIXME(yingwen): Consider move all tests under query/tests to query/src so we could reuse +// more codes. use std::sync::Arc; use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; @@ -22,8 +24,8 @@ use common_recordbatch::{util, RecordBatch}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::PrimitiveElement; -use datatypes::vectors::PrimitiveVector; +use datatypes::types::WrapperType; +use datatypes::vectors::Helper; use query::query_engine::QueryEngineFactory; use query::QueryEngine; use rand::Rng; @@ -47,7 +49,7 @@ pub fn create_query_engine() -> Arc { column_schemas.push(column_schema); let numbers = (1..=10).map(|_| rng.gen::<$T>()).collect::>(); - let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); + let column: VectorRef = Arc::new(<$T as Scalar>::VectorType::from_vec(numbers.to_vec())); columns.push(column); )* } @@ -77,8 +79,7 @@ pub async fn get_numbers_from_table<'s, T>( engine: Arc, ) -> Vec where - T: PrimitiveElement, - for<'a> T: Scalar = T>, + T: WrapperType, { let sql = format!("SELECT {} FROM {}", column_name, table_name); let plan = engine @@ -92,8 +93,21 @@ where }; let numbers = util::collect(recordbatch_stream).await.unwrap(); - let columns = numbers[0].df_recordbatch.columns(); - let column = VectorHelper::try_into_vector(&columns[0]).unwrap(); - let column: &::VectorType = unsafe { VectorHelper::static_cast(&column) }; + let column = numbers[0].column(0); + let column: &::VectorType = unsafe { Helper::static_cast(column) }; column.iter_data().flatten().collect::>() } + +pub fn get_value_from_batches(column_name: &str, batches: Vec) -> Value { + assert_eq!(1, batches.len()); + assert_eq!(batches[0].num_columns(), 1); + assert_eq!(1, batches[0].schema.num_columns()); + assert_eq!(column_name, 
batches[0].schema.column_schemas()[0].name); + + let batch = &batches[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(batch.column(0).len(), 1); + let v = batch.column(0); + assert_eq!(1, v.len()); + v.get(0) +} diff --git a/src/query/tests/mean_test.rs b/src/query/tests/mean_test.rs index 705dea797d..000323fb21 100644 --- a/src/query/tests/mean_test.rs +++ b/src/query/tests/mean_test.rs @@ -12,19 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; mod function; +use std::sync::Arc; + use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; +use datatypes::types::WrapperType; use datatypes::value::OrderedFloat; use format_num::NumberFormat; -use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -33,7 +32,7 @@ use session::context::QueryContext; #[tokio::test] async fn test_mean_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_mean { ([], $( { $T:ty } ),*) => { @@ -53,25 +52,15 @@ async fn test_mean_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive, - for<'a> T: Scalar = T>, + T: WrapperType + AsPrimitive, { let result = execute_mean(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("mean", result[0].schema.arrow_schema().field(0).name()); + let value = function::get_value_from_batches("mean", result); - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); - - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); let expected_value = inc_stats::mean(expected_value.iter().cloned()).unwrap(); diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs index 4e05183861..54d3a62a5b 100644 --- a/src/query/tests/my_sum_udaf_example.rs +++ b/src/query/tests/my_sum_udaf_example.rs @@ -26,12 +26,10 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::prelude::*; use common_query::Output; use common_recordbatch::{util, RecordBatch}; -use datafusion::arrow_print; -use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::{PrimitiveElement, PrimitiveType}; -use datatypes::vectors::PrimitiveVector; +use datatypes::types::{LogicalPrimitiveType, WrapperType}; +use datatypes::vectors::Helper; use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use query::error::Result; @@ -40,28 +38,30 @@ use session::context::QueryContext; use table::test_util::MemTable; #[derive(Debug, Default)] -struct MySumAccumulator 
-where - T: Primitive + AsPrimitive, - SumT: Primitive + std::ops::AddAssign, -{ +struct MySumAccumulator { sum: SumT, _phantom: PhantomData, } impl MySumAccumulator where - T: Primitive + AsPrimitive, - SumT: Primitive + std::ops::AddAssign, + T: WrapperType, + SumT: WrapperType, + T::Native: AsPrimitive, + SumT::Native: std::ops::AddAssign, { #[inline(always)] fn add(&mut self, v: T) { - self.sum += v.as_(); + let mut sum_native = self.sum.into_native(); + sum_native += v.into_native().as_(); + self.sum = SumT::from_native(sum_native); } #[inline(always)] fn merge(&mut self, s: SumT) { - self.sum += s; + let mut sum_native = self.sum.into_native(); + sum_native += s.into_native(); + self.sum = SumT::from_native(sum_native); } } @@ -76,7 +76,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(Box::new(MySumAccumulator::<$S, <$S as Primitive>::LargestType>::default())) + Ok(Box::new(MySumAccumulator::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default())) }, { let err_msg = format!( @@ -95,7 +95,7 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator { with_match_primitive_type_id!( input_type.logical_type_id(), |$S| { - Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().logical_type_id().data_type()) + Ok(<<$S as LogicalPrimitiveType>::LargestType>::build_data_type()) }, { unreachable!() @@ -110,10 +110,10 @@ impl AggregateFunctionCreator for MySumAccumulatorCreator { impl Accumulator for MySumAccumulator where - T: Primitive + AsPrimitive, - for<'a> T: Scalar = T>, - SumT: Primitive + std::ops::AddAssign, - for<'a> SumT: Scalar = SumT>, + T: WrapperType, + SumT: WrapperType, + T::Native: AsPrimitive, + SumT::Native: std::ops::AddAssign, { fn state(&self) -> QueryResult> { Ok(vec![self.sum.into()]) @@ -124,7 +124,7 @@ where return Ok(()); }; let column = &values[0]; - let column: &::VectorType = unsafe { VectorHelper::static_cast(column) }; + let column: &::VectorType = unsafe { Helper::static_cast(column) }; for v in column.iter_data().flatten() { self.add(v) } @@ -136,7 +136,7 @@ where return Ok(()); }; let states = &states[0]; - let states: &::VectorType = unsafe { VectorHelper::static_cast(states) }; + let states: &::VectorType = unsafe { Helper::static_cast(states) }; for s in states.iter_data().flatten() { self.merge(s) } @@ -154,65 +154,57 @@ async fn test_my_sum() -> Result<()> { test_my_sum_with( (1..=10).collect::>(), - vec![ - "+--------+", - "| my_sum |", - "+--------+", - "| 55 |", - "+--------+", - ], + r#"+--------+ +| my_sum | ++--------+ +| 55 | ++--------+"#, ) .await?; test_my_sum_with( (-10..=11).collect::>(), - vec![ - "+--------+", - "| my_sum |", - "+--------+", - "| 11 |", - "+--------+", - ], + r#"+--------+ +| my_sum | ++--------+ +| 11 | ++--------+"#, ) .await?; test_my_sum_with( vec![-1.0f32, 1.0, 2.0, 3.0, 4.0], - vec![ - "+--------+", - "| my_sum |", - "+--------+", - "| 9 |", - "+--------+", - ], + r#"+--------+ +| my_sum | ++--------+ +| 9 | ++--------+"#, ) .await?; test_my_sum_with( vec![u32::MAX, u32::MAX], - vec![ - "+------------+", - "| my_sum |", - "+------------+", - "| 8589934590 |", - "+------------+", - ], + r#"+------------+ +| my_sum | ++------------+ +| 8589934590 | ++------------+"#, ) .await?; Ok(()) } -async fn test_my_sum_with(numbers: Vec, expected: Vec<&str>) -> Result<()> +async fn test_my_sum_with(numbers: Vec, expected: &str) -> Result<()> where - T: 
PrimitiveElement, + T: WrapperType, { let table_name = format!("{}_numbers", std::any::type_name::()); let column_name = format!("{}_number", std::any::type_name::()); let column_schemas = vec![ColumnSchema::new( column_name.clone(), - T::build_data_type(), + T::LogicalType::build_data_type(), true, )]; let schema = Arc::new(Schema::new(column_schemas.clone())); - let column: VectorRef = Arc::new(PrimitiveVector::::from_vec(numbers)); + let column: VectorRef = Arc::new(T::VectorType::from_vec(numbers)); let recordbatch = RecordBatch::new(schema, vec![column]).unwrap(); let testing_table = MemTable::new(&table_name, recordbatch); @@ -236,14 +228,9 @@ where Output::Stream(batch) => batch, _ => unreachable!(), }; - let recordbatch = util::collect(recordbatch_stream).await.unwrap(); - let df_recordbatch = recordbatch - .into_iter() - .map(|r| r.df_recordbatch) - .collect::>(); + let batches = util::collect_batches(recordbatch_stream).await.unwrap(); - let pretty_print = arrow_print::write(&df_recordbatch); - let pretty_print = pretty_print.lines().collect::>(); + let pretty_print = batches.pretty_print().unwrap(); assert_eq!(expected, pretty_print); Ok(()) } diff --git a/src/query/tests/percentile_test.rs b/src/query/tests/percentile_test.rs index 6e210a0494..e639d4b3e6 100644 --- a/src/query/tests/percentile_test.rs +++ b/src/query/tests/percentile_test.rs @@ -20,12 +20,10 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::PrimitiveElement; -use datatypes::vectors::PrimitiveVector; +use datatypes::vectors::Int32Vector; use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; @@ -64,9 +62,8 @@ async fn test_percentile_correctness() -> Result<()> { _ => unreachable!(), }; let record_batch = util::collect(recordbatch_stream).await.unwrap(); - let columns = record_batch[0].df_recordbatch.columns(); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - let value = v.get(0); + let column = record_batch[0].column(0); + let value = column.get(0); assert_eq!(value, Value::from(9.280_000_000_000_001_f64)); Ok(()) } @@ -77,26 +74,12 @@ async fn test_percentile_success( engine: Arc, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive, - for<'a> T: Scalar = T>, + T: WrapperType + AsPrimitive, { let result = execute_percentile(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!( - "percentile", - result[0].schema.arrow_schema().field(0).name() - ); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); + let value = function::get_value_from_batches("percentile", result); let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::>(); @@ -140,9 +123,9 @@ fn create_correctness_engine() -> Arc { let column_schema = ColumnSchema::new("corr_number", 
ConcreteDataType::int32_datatype(), true); column_schemas.push(column_schema); - let numbers = vec![3_i32, 6_i32, 8_i32, 10_i32]; + let numbers = [3_i32, 6_i32, 8_i32, 10_i32]; - let column: VectorRef = Arc::new(PrimitiveVector::::from_vec(numbers.to_vec())); + let column: VectorRef = Arc::new(Int32Vector::from_slice(&numbers)); columns.push(column); let schema = Arc::new(Schema::new(column_schemas)); diff --git a/src/query/tests/polyval_test.rs b/src/query/tests/polyval_test.rs index f2e60c0217..248c0d42d7 100644 --- a/src/query/tests/polyval_test.rs +++ b/src/query/tests/polyval_test.rs @@ -18,11 +18,9 @@ mod function; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -31,13 +29,13 @@ use session::context::QueryContext; #[tokio::test] async fn test_polyval_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! test_polyval { ([], $( { $T:ty } ),*) => { $( let column_name = format!("{}_number", std::any::type_name::<$T>()); - test_polyval_success::<$T,<$T as Primitive>::LargestType>(&column_name, "numbers", engine.clone()).await?; + test_polyval_success::<$T, <<<$T as WrapperType>::LogicalType as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>(&column_name, "numbers", engine.clone()).await?; )* } } @@ -51,36 +49,27 @@ async fn test_polyval_success( engine: Arc, ) -> Result<()> where - T: Primitive + AsPrimitive + PrimitiveElement, - PolyT: Primitive + std::ops::Mul + std::iter::Sum, - for<'a> T: Scalar = T>, - for<'a> PolyT: Scalar = PolyT>, - i64: AsPrimitive, + T: WrapperType, + PolyT: WrapperType, + T::Native: AsPrimitive, + PolyT::Native: std::ops::Mul + std::iter::Sum, + i64: AsPrimitive, { let result = execute_polyval(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("polyval", result[0].schema.arrow_schema().field(0).name()); + let value = function::get_value_from_batches("polyval", result); - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); - - let numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().copied(); let x = 0i64; let len = expected_value.len(); - let expected_value: PolyT = expected_value + let expected_native: PolyT::Native = expected_value .enumerate() - .map(|(i, value)| value.as_() * (x.pow((len - 1 - i) as u32)).as_()) + .map(|(i, v)| v.into_native().as_() * (x.pow((len - 1 - i) as u32)).as_()) .sum(); - assert_eq!(value, expected_value.into()); + assert_eq!(value, PolyT::from_native(expected_native).into()); Ok(()) } diff --git a/src/query/tests/pow.rs b/src/query/tests/pow.rs index 4d9006ca29..d48c28b220 
100644 --- a/src/query/tests/pow.rs +++ b/src/query/tests/pow.rs @@ -32,7 +32,7 @@ pub fn pow(args: &[VectorRef]) -> Result { assert_eq!(exponent.len(), base.len()); - let v = base + let iter = base .iter_data() .zip(exponent.iter_data()) .map(|(base, exponent)| { @@ -42,8 +42,8 @@ pub fn pow(args: &[VectorRef]) -> Result { (Some(base), Some(exponent)) => Some(base.pow(exponent)), _ => None, } - }) - .collect::(); + }); + let v = UInt32Vector::from_owned_iterator(iter); Ok(Arc::new(v) as _) } diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs index cf640afba4..05bb32a2c4 100644 --- a/src/query/tests/query_engine_test.rs +++ b/src/query/tests/query_engine_test.rs @@ -13,30 +13,28 @@ // limitations under the License. mod pow; +// This is used to suppress the warning: function `create_query_engine` is never used. +// FIXME(yingwen): We finally need to refactor these tests and move them to `query/src` +// so tests can share codes with other mods. +#[allow(unused)] +mod function; use std::sync::Arc; -use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; +use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider}; use catalog::{CatalogList, CatalogProvider, SchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::prelude::{create_udf, make_scalar_function, Volatility}; use common_query::Output; -use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; -use datafusion::logical_plan::LogicalPlanBuilder; -use datatypes::arrow::array::UInt32Array; -use datatypes::for_all_primitive_types; +use datafusion::datasource::DefaultTableSource; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::types::{OrdPrimitive, PrimitiveElement}; -use datatypes::vectors::{PrimitiveVector, UInt32Vector}; -use num::NumCast; +use datatypes::vectors::UInt32Vector; use query::error::Result; use query::plan::LogicalPlan; use query::query_engine::QueryEngineFactory; -use query::QueryEngine; -use rand::Rng; use session::context::QueryContext; use table::table::adapter::DfTableProviderAdapter; use table::table::numbers::NumbersTable; @@ -66,12 +64,16 @@ async fn test_datafusion_query_engine() -> Result<()> { let limit = 10; let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone())); let plan = LogicalPlan::DfPlan( - LogicalPlanBuilder::scan("numbers", table_provider, None) - .unwrap() - .limit(limit) - .unwrap() - .build() - .unwrap(), + LogicalPlanBuilder::scan( + "numbers", + Arc::new(DefaultTableSource { table_provider }), + None, + ) + .unwrap() + .limit(0, Some(limit)) + .unwrap() + .build() + .unwrap(), ); let output = engine.execute(&plan).await?; @@ -84,17 +86,17 @@ async fn test_datafusion_query_engine() -> Result<()> { let numbers = util::collect(recordbatch).await.unwrap(); assert_eq!(1, numbers.len()); - assert_eq!(numbers[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, numbers[0].schema.arrow_schema().fields().len()); - assert_eq!("number", numbers[0].schema.arrow_schema().field(0).name()); + assert_eq!(numbers[0].num_columns(), 1); + assert_eq!(1, numbers[0].schema.num_columns()); + assert_eq!("number", numbers[0].schema.column_schemas()[0].name); - let columns = numbers[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), limit); + 
let batch = &numbers[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(batch.column(0).len(), limit); let expected: Vec<u32> = (0u32..limit as u32).collect(); assert_eq!( - *columns[0].as_any().downcast_ref::<UInt32Array>().unwrap(), - UInt32Array::from_slice(&expected) + *batch.column(0), + Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef ); Ok(()) @@ -123,7 +125,8 @@ async fn test_udf() -> Result<()> { let pow = make_scalar_function(pow); let udf = create_udf( - "pow", + // datafusion already supports pow, so we use a different name. + "my_pow", vec![ ConcreteDataType::uint32_datatype(), ConcreteDataType::uint32_datatype(), @@ -136,7 +139,7 @@ async fn test_udf() -> Result<()> { engine.register_udf(udf); let plan = engine.sql_to_plan( - "select pow(number, number) as p from numbers limit 10", + "select my_pow(number, number) as p from numbers limit 10", Arc::new(QueryContext::new()), )?; @@ -148,202 +151,18 @@ async fn test_udf() -> Result<()> { let numbers = util::collect(recordbatch).await.unwrap(); assert_eq!(1, numbers.len()); - assert_eq!(numbers[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, numbers[0].schema.arrow_schema().fields().len()); - assert_eq!("p", numbers[0].schema.arrow_schema().field(0).name()); + assert_eq!(numbers[0].num_columns(), 1); + assert_eq!(1, numbers[0].schema.num_columns()); + assert_eq!("p", numbers[0].schema.column_schemas()[0].name); - let columns = numbers[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 10); + let batch = &numbers[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(batch.column(0).len(), 10); let expected: Vec<u32> = vec![1, 1, 4, 27, 256, 3125, 46656, 823543, 16777216, 387420489]; assert_eq!( - *columns[0].as_any().downcast_ref::<UInt32Array>().unwrap(), - UInt32Array::from_slice(&expected) + *batch.column(0), + Arc::new(UInt32Vector::from_slice(&expected)) as VectorRef ); Ok(()) } - -fn create_query_engine() -> Arc<dyn QueryEngine> { - let schema_provider = Arc::new(MemorySchemaProvider::new()); - let catalog_provider = Arc::new(MemoryCatalogProvider::new()); - let catalog_list = Arc::new(MemoryCatalogManager::default()); - - // create table with primitives, and all columns' length are even - let mut column_schemas = vec![]; - let mut columns = vec![]; - macro_rules! create_even_number_table { - ([], $( { $T:ty } ),*) => { - $( - let mut rng = rand::thread_rng(); - - let column_name = format!("{}_number_even", std::any::type_name::<$T>()); - let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true); - column_schemas.push(column_schema); - - let numbers = (1..=100).map(|_| rng.gen::<$T>()).collect::<Vec<$T>>(); - let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); - columns.push(column); - )* - } - } - for_all_primitive_types! { create_even_number_table } - - let schema = Arc::new(Schema::new(column_schemas.clone())); - let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let even_number_table = Arc::new(MemTable::new("even_numbers", recordbatch)); - schema_provider - .register_table( - even_number_table.table_name().to_string(), - even_number_table, - ) - .unwrap(); - - // create table with primitives, and all columns' length are odd - let mut column_schemas = vec![]; - let mut columns = vec![]; - macro_rules!
create_odd_number_table { - ([], $( { $T:ty } ),*) => { - $( - let mut rng = rand::thread_rng(); - - let column_name = format!("{}_number_odd", std::any::type_name::<$T>()); - let column_schema = ColumnSchema::new(column_name, Value::from(<$T>::default()).data_type(), true); - column_schemas.push(column_schema); - - let numbers = (1..=99).map(|_| rng.gen::<$T>()).collect::>(); - let column: VectorRef = Arc::new(PrimitiveVector::<$T>::from_vec(numbers.to_vec())); - columns.push(column); - )* - } - } - for_all_primitive_types! { create_odd_number_table } - - let schema = Arc::new(Schema::new(column_schemas.clone())); - let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let odd_number_table = Arc::new(MemTable::new("odd_numbers", recordbatch)); - schema_provider - .register_table(odd_number_table.table_name().to_string(), odd_number_table) - .unwrap(); - - catalog_provider - .register_schema(DEFAULT_SCHEMA_NAME.to_string(), schema_provider) - .unwrap(); - catalog_list - .register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider) - .unwrap(); - - QueryEngineFactory::new(catalog_list).query_engine() -} - -async fn get_numbers_from_table<'s, T>( - column_name: &'s str, - table_name: &'s str, - engine: Arc, -) -> Vec> -where - T: PrimitiveElement, - for<'a> T: Scalar = T>, -{ - let sql = format!("SELECT {} FROM {}", column_name, table_name); - let plan = engine - .sql_to_plan(&sql, Arc::new(QueryContext::new())) - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - let numbers = util::collect(recordbatch_stream).await.unwrap(); - - let columns = numbers[0].df_recordbatch.columns(); - let column = VectorHelper::try_into_vector(&columns[0]).unwrap(); - let column: &::VectorType = unsafe { VectorHelper::static_cast(&column) }; - column - .iter_data() - .flatten() - .map(|x| OrdPrimitive::(x)) - .collect::>>() -} - -#[tokio::test] -async fn test_median_aggregator() -> Result<()> { - common_telemetry::init_default_ut_logging(); - - let engine = create_query_engine(); - - macro_rules! test_median { - ([], $( { $T:ty } ),*) => { - $( - let column_name = format!("{}_number_even", std::any::type_name::<$T>()); - test_median_success::<$T>(&column_name, "even_numbers", engine.clone()).await?; - - let column_name = format!("{}_number_odd", std::any::type_name::<$T>()); - test_median_success::<$T>(&column_name, "odd_numbers", engine.clone()).await?; - )* - } - } - for_all_primitive_types! 
{ test_median } - Ok(()) -} - -async fn test_median_success( - column_name: &str, - table_name: &str, - engine: Arc, -) -> Result<()> -where - T: PrimitiveElement, - for<'a> T: Scalar = T>, -{ - let result = execute_median(column_name, table_name, engine.clone()) - .await - .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!("median", result[0].schema.arrow_schema().field(0).name()); - - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let median = v.get(0); - - let mut numbers = get_numbers_from_table::(column_name, table_name, engine.clone()).await; - numbers.sort(); - let len = numbers.len(); - let expected_median: Value = if len % 2 == 1 { - numbers[len / 2] - } else { - let a: f64 = NumCast::from(numbers[len / 2 - 1].as_primitive()).unwrap(); - let b: f64 = NumCast::from(numbers[len / 2].as_primitive()).unwrap(); - OrdPrimitive::(NumCast::from(a / 2.0 + b / 2.0).unwrap()) - } - .into(); - assert_eq!(expected_median, median); - Ok(()) -} - -async fn execute_median<'a>( - column_name: &'a str, - table_name: &'a str, - engine: Arc, -) -> RecordResult> { - let sql = format!( - "select MEDIAN({}) as median from {}", - column_name, table_name - ); - let plan = engine - .sql_to_plan(&sql, Arc::new(QueryContext::new())) - .unwrap(); - - let output = engine.execute(&plan).await.unwrap(); - let recordbatch_stream = match output { - Output::Stream(batch) => batch, - _ => unreachable!(), - }; - util::collect(recordbatch_stream).await -} diff --git a/src/query/tests/scipy_stats_norm_cdf_test.rs b/src/query/tests/scipy_stats_norm_cdf_test.rs index 815501a314..dee8f5c87e 100644 --- a/src/query/tests/scipy_stats_norm_cdf_test.rs +++ b/src/query/tests/scipy_stats_norm_cdf_test.rs @@ -18,11 +18,8 @@ mod function; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -33,7 +30,7 @@ use statrs::statistics::Statistics; #[tokio::test] async fn test_scipy_stats_norm_cdf_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules! 
test_scipy_stats_norm_cdf { ([], $( { $T:ty } ),*) => { @@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_cdf_success<T>( engine: Arc<dyn QueryEngine>, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive<f64>, - for<'a> T: Scalar<RefType<'a> = T>, + T: WrapperType + AsPrimitive<f64>, { let result = execute_scipy_stats_norm_cdf(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!( - "scipy_stats_norm_cdf", - result[0].schema.arrow_schema().field(0).name() - ); + let value = function::get_value_from_batches("scipy_stats_norm_cdf", result); - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); - - let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>(); let mean = expected_value.clone().mean(); let stddev = expected_value.std_dev(); diff --git a/src/query/tests/scipy_stats_norm_pdf.rs b/src/query/tests/scipy_stats_norm_pdf.rs index dd5e0fc7fc..03e4cf1292 100644 --- a/src/query/tests/scipy_stats_norm_pdf.rs +++ b/src/query/tests/scipy_stats_norm_pdf.rs @@ -18,11 +18,8 @@ mod function; use common_query::Output; use common_recordbatch::error::Result as RecordResult; use common_recordbatch::{util, RecordBatch}; -use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::for_all_primitive_types; -use datatypes::prelude::*; -use datatypes::types::PrimitiveElement; -use function::{create_query_engine, get_numbers_from_table}; +use datatypes::types::WrapperType; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; @@ -33,7 +30,7 @@ use statrs::statistics::Statistics; #[tokio::test] async fn test_scipy_stats_norm_pdf_aggregator() -> Result<()> { common_telemetry::init_default_ut_logging(); - let engine = create_query_engine(); + let engine = function::create_query_engine(); macro_rules!
test_scipy_stats_norm_pdf { ([], $( { $T:ty } ),*) => { @@ -53,28 +50,15 @@ async fn test_scipy_stats_norm_pdf_success<T>( engine: Arc<dyn QueryEngine>, ) -> Result<()> where - T: PrimitiveElement + AsPrimitive<f64>, - for<'a> T: Scalar<RefType<'a> = T>, + T: WrapperType + AsPrimitive<f64>, { let result = execute_scipy_stats_norm_pdf(column_name, table_name, engine.clone()) .await .unwrap(); - assert_eq!(1, result.len()); - assert_eq!(result[0].df_recordbatch.num_columns(), 1); - assert_eq!(1, result[0].schema.arrow_schema().fields().len()); - assert_eq!( - "scipy_stats_norm_pdf", - result[0].schema.arrow_schema().field(0).name() - ); + let value = function::get_value_from_batches("scipy_stats_norm_pdf", result); - let columns = result[0].df_recordbatch.columns(); - assert_eq!(1, columns.len()); - assert_eq!(columns[0].len(), 1); - let v = VectorHelper::try_into_vector(&columns[0]).unwrap(); - assert_eq!(1, v.len()); - let value = v.get(0); - - let numbers = get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await; + let numbers = + function::get_numbers_from_table::<T>(column_name, table_name, engine.clone()).await; let expected_value = numbers.iter().map(|&n| n.as_()).collect::<Vec<f64>>(); let mean = expected_value.clone().mean(); let stddev = expected_value.std_dev(); diff --git a/src/sql/Cargo.toml b/src/sql/Cargo.toml index 6f7f40b017..ebdd0f172b 100644 --- a/src/sql/Cargo.toml +++ b/src/sql/Cargo.toml @@ -15,4 +15,4 @@ itertools = "0.10" mito = { path = "../mito" } once_cell = "1.10" snafu = { version = "0.7", features = ["backtraces"] } -sqlparser = "0.15.0" +sqlparser = "0.26" diff --git a/src/sql/src/ast.rs b/src/sql/src/ast.rs index 11636df8c0..7388b9453c 100644 --- a/src/sql/src/ast.rs +++ b/src/sql/src/ast.rs @@ -14,5 +14,5 @@ pub use sqlparser::ast::{ ColumnDef, ColumnOption, ColumnOptionDef, DataType, Expr, Function, FunctionArg, - FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, Value, + FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, TimezoneInfo, Value, }; diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 254982e88e..3a14fb0666 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -505,11 +505,7 @@ mod tests { assert_matches!( &stmts[0], Statement::ShowTables(ShowTables { - kind: ShowKind::Where(sqlparser::ast::Expr::BinaryOp { - left: _, - right: _, - op: sqlparser::ast::BinaryOperator::Like, - }), + kind: ShowKind::Where(sqlparser::ast::Expr::Like { .. }), database: None, }) ); @@ -522,11 +518,7 @@ mod tests { assert_matches!( &stmts[0], Statement::ShowTables(ShowTables { - kind: ShowKind::Where(sqlparser::ast::Expr::BinaryOp { - left: _, - right: _, - op: sqlparser::ast::BinaryOperator::Like, - }), + kind: ShowKind::Where(sqlparser::ast::Expr::Like { ..
}), database: Some(_), }) ); @@ -543,11 +535,12 @@ mod tests { distinct: false, top: None, projection: vec![sqlparser::ast::SelectItem::Wildcard], + into: None, from: vec![sqlparser::ast::TableWithJoins { relation: sqlparser::ast::TableFactor::Table { name: sqlparser::ast::ObjectName(vec![sqlparser::ast::Ident::new("foo")]), alias: None, - args: vec![], + args: None, with_hints: vec![], }, joins: vec![], @@ -559,11 +552,12 @@ mod tests { distribute_by: vec![], sort_by: vec![], having: None, + qualify: None, }; let sp_statement = SpStatement::Query(Box::new(SpQuery { with: None, - body: sqlparser::ast::SetExpr::Select(Box::new(select)), + body: Box::new(sqlparser::ast::SetExpr::Select(Box::new(select))), order_by: vec![], limit: None, offset: None, @@ -576,6 +570,7 @@ mod tests { analyze: false, verbose: false, statement: Box::new(sp_statement), + format: None, }) .unwrap(); diff --git a/src/sql/src/statements.rs b/src/sql/src/statements.rs index b88f9380fa..c82a2caaa8 100644 --- a/src/sql/src/statements.rs +++ b/src/sql/src/statements.rs @@ -318,7 +318,7 @@ pub fn sql_data_type_to_concrete_data_type(data_type: &SqlDataType) -> Result<ConcreteDataType> { - SqlDataType::Timestamp => Ok(ConcreteDataType::timestamp_millisecond_datatype()), + SqlDataType::Timestamp(_) => Ok(ConcreteDataType::timestamp_millisecond_datatype()), _ => error::SqlTypeNotSupportedSnafu { t: data_type.clone(), } @@ -336,7 +336,7 @@ mod tests { use datatypes::value::OrderedFloat; use super::*; - use crate::ast::{DataType, Ident}; + use crate::ast::{Ident, TimezoneInfo}; use crate::statements::ColumnOption; fn check_type(sql_type: SqlDataType, data_type: ConcreteDataType) { @@ -376,7 +376,7 @@ mod tests { ConcreteDataType::datetime_datatype(), ); check_type( - SqlDataType::Timestamp, + SqlDataType::Timestamp(TimezoneInfo::None), ConcreteDataType::timestamp_millisecond_datatype(), ); } @@ -573,7 +573,7 @@ mod tests { // test basic let column_def = ColumnDef { name: "col".into(), - data_type: DataType::Double, + data_type: SqlDataType::Double, collation: None, options: vec![], }; @@ -588,7 +588,7 @@ mod tests { // test not null let column_def = ColumnDef { name: "col".into(), - data_type: DataType::Double, + data_type: SqlDataType::Double, collation: None, options: vec![ColumnOptionDef { name: None, diff --git a/src/sql/src/statements/insert.rs b/src/sql/src/statements/insert.rs index 410c0d09cb..f105648ea8 100644 --- a/src/sql/src/statements/insert.rs +++ b/src/sql/src/statements/insert.rs @@ -49,7 +49,7 @@ impl Insert { pub fn values(&self) -> Result<Vec<Vec<Value>>> { let values = match &self.inner { - Statement::Insert { source, .. } => match &source.body { + Statement::Insert { source, .. } => match &*source.body { SetExpr::Values(Values(exprs)) => sql_exprs_to_values(exprs)?, _ => unreachable!(), },