feat(flow): transform substrait SELECT&WHERE&GROUP BY to Flow Plan (#3690)

* feat: transofrm substrait SELECT&WHERE&GROUP BY to Flow Plan

* chore: reexport from common/substrait

* feat: use datafusion Aggr Func to map to Flow aggr func

* chore: remove unwrap&split literal

* refactor: split transform.rs into smaller files

* feat: apply optimize for variadic fn

* refactor: split unit test

* chore: per review
This commit is contained in:
discord9
2024-04-12 15:38:42 +08:00
committed by GitHub
parent 544c4a70f8
commit db329f6c80
15 changed files with 1559 additions and 26 deletions

View File

@@ -17,12 +17,12 @@
mod df_substrait;
pub mod error;
pub mod extension_serializer;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion::catalog::CatalogList;
pub use substrait_proto;
pub use crate::df_substrait::DFLogicalSubstraitConvertor;

View File

@@ -29,6 +29,7 @@ servers.workspace = true
smallvec.workspace = true
snafu.workspace = true
strum.workspace = true
substrait.workspace = true
tokio.workspace = true
tonic.workspace = true
@@ -39,5 +40,4 @@ prost.workspace = true
query.workspace = true
serde_json = "1.0"
session.workspace = true
substrait.workspace = true
table.workspace = true

View File

@@ -73,6 +73,13 @@ pub enum Error {
extra: String,
location: Location,
},
#[snafu(display("Datafusion error: {raw:?} in context: {context}"))]
Datafusion {
raw: datafusion_common::DataFusionError,
context: String,
location: Location,
},
}
/// Result type for flow module
@@ -81,7 +88,9 @@ pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Self::Eval { .. } | &Self::JoinTask { .. } => StatusCode::Internal,
Self::Eval { .. } | &Self::JoinTask { .. } | &Self::Datafusion { .. } => {
StatusCode::Internal
}
&Self::TableAlreadyExist { .. } => StatusCode::TableAlreadyExists,
Self::TableNotFound { .. } => StatusCode::TableNotFound,
&Self::InvalidQuery { .. } | &Self::Plan { .. } | &Self::Datatypes { .. } => {

View File

@@ -344,7 +344,7 @@ mod test {
(Row::new(vec![2i64.into()]), 2, 1),
(Row::new(vec![3i64.into()]), 3, 1),
];
let collection = ctx.render_constant(rows.clone());
let collection = ctx.render_constant(rows);
ctx.insert_global(GlobalId::User(1), collection);
let input_plan = Plan::Get {
id: expr::Id::Global(GlobalId::User(1)),
@@ -440,7 +440,7 @@ mod test {
(Row::new(vec![2.into()]), 2, 1),
(Row::new(vec![3.into()]), 3, 1),
];
let collection = ctx.render_constant(rows.clone());
let collection = ctx.render_constant(rows);
ctx.insert_global(GlobalId::User(1), collection);
let input_plan = Plan::Get {
id: expr::Id::Global(GlobalId::User(1)),
@@ -490,7 +490,7 @@ mod test {
(Row::empty(), 2, 1),
(Row::empty(), 3, 1),
];
let collection = ctx.render_constant(rows.clone());
let collection = ctx.render_constant(rows);
let collection = collection.collection.clone(ctx.df);
let cnt = Rc::new(RefCell::new(0));
let cnt_inner = cnt.clone();

View File

@@ -27,4 +27,4 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}
pub(crate) use id::{GlobalId, Id, LocalId};
pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan};
pub(crate) use relation::{AggregateExpr, AggregateFunc};
pub(crate) use scalar::ScalarExpr;
pub(crate) use scalar::{ScalarExpr, TypedExpr};

View File

@@ -501,8 +501,8 @@ impl BinaryFunc {
let spec_fn = Self::specialization(generic_fn, query_input_type)?;
let signature = Signature {
input: smallvec![arg_type.clone(), arg_type.clone()],
output: spec_fn.signature().output.clone(),
input: smallvec![arg_type.clone(), arg_type],
output: spec_fn.signature().output,
generic_fn,
};
@@ -767,7 +767,7 @@ fn test_num_ops() {
assert_eq!(res, Value::from(30));
let res = div::<i32>(left.clone(), right.clone()).unwrap();
assert_eq!(res, Value::from(3));
let res = rem::<i32>(left.clone(), right.clone()).unwrap();
let res = rem::<i32>(left, right).unwrap();
assert_eq!(res, Value::from(1));
let values = vec![Value::from(true), Value::from(false)];

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::OnceLock;
use common_time::{Date, DateTime};
@@ -20,10 +21,10 @@ use datatypes::prelude::ConcreteDataType;
use datatypes::value::{OrderedF32, OrderedF64, Value};
use serde::{Deserialize, Serialize};
use smallvec::smallvec;
use snafu::OptionExt;
use snafu::{OptionExt, ResultExt};
use strum::{EnumIter, IntoEnumIterator};
use crate::adapter::error::{Error, InvalidQuerySnafu};
use crate::adapter::error::{DatafusionSnafu, Error, InvalidQuerySnafu};
use crate::expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu};
use crate::expr::relation::accum::{Accum, Accumulator};
use crate::expr::signature::{GenericFn, Signature};
@@ -172,17 +173,32 @@ impl AggregateFunc {
}
spec
});
use datafusion_expr::aggregate_function::AggregateFunction as DfAggrFunc;
let df_aggr_func = DfAggrFunc::from_str(name).or_else(|err| {
if let datafusion_common::DataFusionError::NotImplemented(msg) = err {
InvalidQuerySnafu {
reason: format!("Unsupported aggregate function: {}", msg),
}
.fail()
} else {
DatafusionSnafu {
raw: err,
context: "Error when parsing aggregate function",
}
.fail()
}
})?;
let generic_fn = match name {
"max" => GenericFn::Max,
"min" => GenericFn::Min,
"sum" => GenericFn::Sum,
"count" => GenericFn::Count,
"any" => GenericFn::Any,
"all" => GenericFn::All,
let generic_fn = match df_aggr_func {
DfAggrFunc::Max => GenericFn::Max,
DfAggrFunc::Min => GenericFn::Min,
DfAggrFunc::Sum => GenericFn::Sum,
DfAggrFunc::Count => GenericFn::Count,
DfAggrFunc::BoolOr => GenericFn::Any,
DfAggrFunc::BoolAnd => GenericFn::All,
_ => {
return InvalidQuerySnafu {
reason: format!("Unknown binary function: {}", name),
reason: format!("Unknown aggregate function: {}", name),
}
.fail();
}

View File

@@ -24,6 +24,22 @@ use snafu::ensure;
use crate::adapter::error::{Error, InvalidQuerySnafu, UnsupportedTemporalFilterSnafu};
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
use crate::repr::ColumnType;
/// A scalar expression with a known type.
#[derive(Debug, Clone)]
pub struct TypedExpr {
/// The expression.
pub expr: ScalarExpr,
/// The type of the expression.
pub typ: ColumnType,
}
impl TypedExpr {
pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self {
Self { expr, typ }
}
}
/// A scalar expression, which can be evaluated to a value.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -64,6 +80,38 @@ pub enum ScalarExpr {
},
}
impl ScalarExpr {
/// apply optimization to the expression, like flatten variadic function
pub fn optimize(&mut self) {
self.flatten_varidic_fn();
}
/// Because Substrait's `And`/`Or` function is binary, but FlowPlan's
/// `And`/`Or` function is variadic, we need to flatten the `And` function if multiple `And`/`Or` functions are nested.
fn flatten_varidic_fn(&mut self) {
if let ScalarExpr::CallVariadic { func, exprs } = self {
let mut new_exprs = vec![];
for expr in std::mem::take(exprs) {
if let ScalarExpr::CallVariadic {
func: inner_func,
exprs: mut inner_exprs,
} = expr
{
if *func == inner_func {
for inner_expr in inner_exprs.iter_mut() {
inner_expr.flatten_varidic_fn();
}
new_exprs.extend(inner_exprs);
}
} else {
new_exprs.push(expr);
}
}
*exprs = new_exprs;
}
}
}
impl ScalarExpr {
/// Call a unary function on this expression.
pub fn call_unary(self, func: UnaryFunc) -> Self {

View File

@@ -27,4 +27,5 @@ mod compute;
mod expr;
mod plan;
mod repr;
mod transform;
mod utils;

View File

@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
pub(crate) use self::reduce::{AccumulablePlan, KeyValPlan, ReducePlan};
use crate::adapter::error::Error;
use crate::expr::{
AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr,
AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr, TypedExpr,
};
use crate::plan::join::JoinPlan;
use crate::repr::{ColumnType, DiffRow, RelationType};
@@ -61,10 +61,13 @@ impl TypedPlan {
}
/// project the plan to the given expressions
pub fn projection(self, exprs: Vec<(ScalarExpr, ColumnType)>) -> Result<Self, Error> {
pub fn projection(self, exprs: Vec<TypedExpr>) -> Result<Self, Error> {
let input_arity = self.typ.column_types.len();
let output_arity = exprs.len();
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs.into_iter().unzip();
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
.into_iter()
.map(|TypedExpr { expr, typ }| (expr, typ))
.unzip();
let mfp = MapFilterProject::new(input_arity)
.map(exprs)?
.project(input_arity..input_arity + output_arity)?;
@@ -87,18 +90,19 @@ impl TypedPlan {
}
/// Add a new filter to the plan, will filter out the records that do not satisfy the filter
pub fn filter(self, filter: (ScalarExpr, ColumnType)) -> Result<Self, Error> {
pub fn filter(self, filter: TypedExpr) -> Result<Self, Error> {
let plan = match self.plan {
Plan::Mfp {
input,
mfp: old_mfp,
} => Plan::Mfp {
input,
mfp: old_mfp.filter(vec![filter.0])?,
mfp: old_mfp.filter(vec![filter.expr])?,
},
_ => Plan::Mfp {
input: Box::new(self.plan),
mfp: MapFilterProject::new(self.typ.column_types.len()).filter(vec![filter.0])?,
mfp: MapFilterProject::new(self.typ.column_types.len())
.filter(vec![filter.expr])?,
},
};
Ok(TypedPlan {

179
src/flow/src/transform.rs Normal file
View File

@@ -0,0 +1,179 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Transform Substrait into execution plan
use std::collections::HashMap;
use datatypes::data_type::ConcreteDataType as CDT;
use crate::adapter::error::{Error, NotImplementedSnafu, TableNotFoundSnafu};
use crate::expr::GlobalId;
use crate::repr::RelationType;
/// a simple macro to generate a not implemented error
macro_rules! not_impl_err {
($($arg:tt)*) => {
NotImplementedSnafu {
reason: format!($($arg)*),
}.fail()
};
}
/// generate a plan error
macro_rules! plan_err {
($($arg:tt)*) => {
PlanSnafu {
reason: format!($($arg)*),
}.fail()
};
}
mod aggr;
mod expr;
mod literal;
mod plan;
use literal::{from_substrait_literal, from_substrait_type};
use snafu::OptionExt;
use substrait::substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
use substrait::substrait_proto::proto::extensions::SimpleExtensionDeclaration;
/// In Substrait, a function can be define by an u32 anchor, and the anchor can be mapped to a name
///
/// So in substrait plan, a ref to a function can be a single u32 anchor instead of a full name in string
pub struct FunctionExtensions {
anchor_to_name: HashMap<u32, String>,
}
impl FunctionExtensions {
/// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration
pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
let mut anchor_to_name = HashMap::new();
for e in extensions {
match &e.mapping_type {
Some(ext) => match ext {
MappingType::ExtensionFunction(ext_f) => {
anchor_to_name.insert(ext_f.function_anchor, ext_f.name.clone());
}
_ => not_impl_err!("Extension type not supported: {ext:?}")?,
},
None => not_impl_err!("Cannot parse empty extension")?,
}
}
Ok(Self { anchor_to_name })
}
/// Get the name of a function by it's anchor
pub fn get(&self, anchor: &u32) -> Option<&String> {
self.anchor_to_name.get(anchor)
}
}
/// A context that holds the information of the dataflow
pub struct DataflowContext {
/// `id` refer to any source table in the dataflow, and `name` is the name of the table
/// which is a `Vec<String>` in substrait
id_to_name: HashMap<GlobalId, Vec<String>>,
/// see `id_to_name`
name_to_id: HashMap<Vec<String>, GlobalId>,
/// the schema of the table
schema: HashMap<GlobalId, RelationType>,
}
impl DataflowContext {
/// Retrieves a GlobalId and table schema representing a table previously registered by calling the [register_table] function.
///
/// Returns an error if no table has been registered with the provided names
pub fn table(&self, name: &Vec<String>) -> Result<(GlobalId, RelationType), Error> {
let id = self
.name_to_id
.get(name)
.copied()
.with_context(|| TableNotFoundSnafu {
name: name.join("."),
})?;
let schema = self
.schema
.get(&id)
.cloned()
.with_context(|| TableNotFoundSnafu {
name: name.join("."),
})?;
Ok((id, schema))
}
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use catalog::RegisterTableRequest;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
use prost::Message;
use query::parser::QueryLanguageParser;
use query::plan::LogicalPlan;
use query::QueryEngine;
use session::context::QueryContext;
use substrait::substrait_proto::proto;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};
use super::*;
use crate::repr::ColumnType;
pub fn create_test_ctx() -> DataflowContext {
let gid = GlobalId::User(0);
let name = vec!["numbers".to_string()];
let schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]);
DataflowContext {
id_to_name: HashMap::from([(gid, name.clone())]),
name_to_id: HashMap::from([(name.clone(), gid)]),
schema: HashMap::from([(gid, schema)]),
}
}
pub fn create_test_query_engine() -> Arc<dyn QueryEngine> {
let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap();
let req = RegisterTableRequest {
catalog: DEFAULT_CATALOG_NAME.to_string(),
schema: DEFAULT_SCHEMA_NAME.to_string(),
table_name: NUMBERS_TABLE_NAME.to_string(),
table_id: NUMBERS_TABLE_ID,
table: NumbersTable::table(NUMBERS_TABLE_ID),
};
catalog_list.register_table_sync(req).unwrap();
let factory = query::QueryEngineFactory::new(catalog_list, None, None, None, false);
let engine = factory.query_engine();
assert_eq!("datafusion", engine.name());
engine
}
pub async fn sql_to_substrait(engine: Arc<dyn QueryEngine>, sql: &str) -> proto::Plan {
// let engine = create_test_query_engine();
let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
let plan = engine
.planner()
.plan(stmt, QueryContext::arc())
.await
.unwrap();
let LogicalPlan::DfPlan(plan) = plan;
// encode then decode so to rely on the impl of conversion from logical plan to substrait plan
let bytes = DFLogicalSubstraitConvertor {}.encode(&plan).unwrap();
proto::Plan::decode(bytes).unwrap()
}
}

View File

@@ -0,0 +1,446 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use common_decimal::Decimal128;
use common_time::{Date, Timestamp};
use datafusion_substrait::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DEFAULT_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
UNSIGNED_INTEGER_TYPE_REF,
};
use datatypes::arrow::compute::kernels::window;
use datatypes::arrow::ipc::Binary;
use datatypes::data_type::ConcreteDataType as CDT;
use datatypes::value::Value;
use hydroflow::futures::future::Map;
use itertools::Itertools;
use snafu::{OptionExt, ResultExt};
use substrait::substrait_proto::proto::aggregate_function::AggregationInvocation;
use substrait::substrait_proto::proto::aggregate_rel::{Grouping, Measure};
use substrait::substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
use substrait::substrait_proto::proto::expression::literal::LiteralType;
use substrait::substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
use substrait::substrait_proto::proto::expression::{
IfThen, Literal, MaskExpression, RexType, ScalarFunction,
};
use substrait::substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
use substrait::substrait_proto::proto::extensions::SimpleExtensionDeclaration;
use substrait::substrait_proto::proto::function_argument::ArgType;
use substrait::substrait_proto::proto::r#type::Kind;
use substrait::substrait_proto::proto::read_rel::ReadType;
use substrait::substrait_proto::proto::rel::RelType;
use substrait::substrait_proto::proto::{self, plan_rel, Expression, Plan as SubPlan, Rel};
use crate::adapter::error::{
DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu,
TableNotFoundSnafu,
};
use crate::expr::{
AggregateExpr, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, SafeMfpPlan, ScalarExpr,
TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
};
use crate::plan::{AccumulablePlan, KeyValPlan, Plan, ReducePlan, TypedPlan};
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::{DataflowContext, FunctionExtensions};
impl TypedExpr {
fn from_substrait_agg_grouping(
ctx: &mut DataflowContext,
groupings: &[Grouping],
typ: &RelationType,
extensions: &FunctionExtensions,
) -> Result<Vec<TypedExpr>, Error> {
let _ = ctx;
let mut group_expr = vec![];
match groupings.len() {
1 => {
for e in &groupings[0].grouping_expressions {
let x = TypedExpr::from_substrait_rex(e, typ, extensions)?;
group_expr.push(x);
}
}
_ => {
return not_impl_err!(
"Grouping sets not support yet, use union all with group by instead."
);
}
};
Ok(group_expr)
}
}
impl AggregateExpr {
fn from_substrait_agg_measures(
ctx: &mut DataflowContext,
measures: &[Measure],
typ: &RelationType,
extensions: &FunctionExtensions,
) -> Result<Vec<AggregateExpr>, Error> {
let _ = ctx;
let mut aggr_exprs = vec![];
for m in measures {
let filter = &m
.filter
.as_ref()
.map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
.transpose()?;
let agg_func = match &m.measure {
Some(f) => {
let distinct = match f.invocation {
_ if f.invocation == AggregationInvocation::Distinct as i32 => true,
_ if f.invocation == AggregationInvocation::All as i32 => false,
_ => false,
};
AggregateExpr::from_substrait_agg_func(
f, typ, extensions, filter, // TODO(discord9): impl order_by
&None, distinct,
)
}
None => not_impl_err!("Aggregate without aggregate function is not supported"),
}?;
aggr_exprs.push(agg_func);
}
Ok(aggr_exprs)
}
/// Convert AggregateFunction into Flow's AggregateExpr
pub fn from_substrait_agg_func(
f: &proto::AggregateFunction,
input_schema: &RelationType,
extensions: &FunctionExtensions,
filter: &Option<TypedExpr>,
order_by: &Option<Vec<TypedExpr>>,
distinct: bool,
) -> Result<AggregateExpr, Error> {
// TODO(discord9): impl filter
let _ = filter;
let _ = order_by;
let mut args = vec![];
for arg in &f.arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
TypedExpr::from_substrait_rex(e, input_schema, extensions)
}
_ => not_impl_err!("Aggregated function argument non-Value type not supported"),
}?;
args.push(arg_expr);
}
let arg = if let Some(first) = args.first() {
first
} else {
return not_impl_err!("Aggregated function without arguments is not supported");
};
let func = match extensions.get(&f.function_reference) {
Some(function_name) => {
AggregateFunc::from_str_and_type(function_name, Some(arg.typ.scalar_type.clone()))
}
None => not_impl_err!(
"Aggregated function not found: function anchor = {:?}",
f.function_reference
),
}?;
Ok(AggregateExpr {
func,
expr: arg.expr.clone(),
distinct,
})
}
}
impl KeyValPlan {
/// Generate KeyValPlan from AggregateExpr and group_exprs
///
/// will also change aggregate expr to use column ref if necessary
fn from_substrait_gen_key_val_plan(
aggr_exprs: &mut [AggregateExpr],
group_exprs: &[TypedExpr],
input_arity: usize,
) -> Result<KeyValPlan, Error> {
let group_expr_val = group_exprs
.iter()
.cloned()
.map(|expr| expr.expr.clone())
.collect_vec();
let output_arity = group_expr_val.len();
let key_plan = MapFilterProject::new(input_arity)
.map(group_expr_val)?
.project(input_arity..input_arity + output_arity)?;
// val_plan is extracted from aggr_exprs to give aggr function it's necessary input
// and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref
let val_plan = {
let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
if need_mfp {
// create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp
let input_exprs = aggr_exprs
.iter_mut()
.enumerate()
.map(|(idx, aggr)| {
let ret = aggr.expr.clone();
aggr.expr = ScalarExpr::Column(idx);
ret
})
.collect_vec();
let aggr_arity = aggr_exprs.len();
MapFilterProject::new(input_arity)
.map(input_exprs)?
.project(input_arity..input_arity + aggr_arity)?
} else {
// simply take all inputs as value
MapFilterProject::new(input_arity)
}
};
Ok(KeyValPlan {
key_plan: key_plan.into_safe(),
val_plan: val_plan.into_safe(),
})
}
}
impl TypedPlan {
/// Convert AggregateRel into Flow's TypedPlan
pub fn from_substrait_agg_rel(
ctx: &mut DataflowContext,
agg: &proto::AggregateRel,
extensions: &FunctionExtensions,
) -> Result<TypedPlan, Error> {
let input = if let Some(input) = agg.input.as_ref() {
TypedPlan::from_substrait_rel(ctx, input, extensions)?
} else {
return not_impl_err!("Aggregate without an input is not supported");
};
let group_expr =
TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.typ, extensions)?;
let mut aggr_exprs =
AggregateExpr::from_substrait_agg_measures(ctx, &agg.measures, &input.typ, extensions)?;
let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
&mut aggr_exprs,
&group_expr,
input.typ.column_types.len(),
)?;
let output_type = {
let mut output_types = Vec::new();
// first append group_expr as key, then aggr_expr as value
for expr in &group_expr {
output_types.push(expr.typ.clone());
}
for aggr in &aggr_exprs {
output_types.push(ColumnType::new_nullable(
aggr.func.signature().output.clone(),
));
}
RelationType::new(output_types)
};
// copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
// also set them input/output column
let full_aggrs = aggr_exprs;
let mut simple_aggrs = Vec::new();
let mut distinct_aggrs = Vec::new();
for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
reason: "Expect aggregate argument to be transformed into a column at this point",
})?;
if aggr_expr.distinct {
distinct_aggrs.push((output_column, input_column, aggr_expr.clone()));
} else {
simple_aggrs.push((output_column, input_column, aggr_expr.clone()));
}
}
let accum_plan = AccumulablePlan {
full_aggrs,
simple_aggrs,
distinct_aggrs,
};
let plan = Plan::Reduce {
input: Box::new(input.plan),
key_val_plan,
reduce_plan: ReducePlan::Accumulable(accum_plan),
};
Ok(TypedPlan {
typ: output_type,
plan,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::plan::{Plan, TypedPlan};
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
#[tokio::test]
async fn test_sum() {
let engine = create_test_query_engine();
let sql = "SELECT sum(number) FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(Plan::Reduce {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(1)
.project(vec![])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(1)
.project(vec![0])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![(0, 0, aggr_expr.clone())],
distinct_aggrs: vec![],
}),
}),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)])
.unwrap()
.project(vec![1])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
#[tokio::test]
async fn test_sum_group_by() {
let engine = create_test_query_engine();
let sql = "SELECT sum(number), number FROM numbers GROUP BY number";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();
let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), true),
ColumnType::new(CDT::uint32_datatype(), false),
]),
plan: Plan::Mfp {
input: Box::new(Plan::Reduce {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)])
.unwrap()
.project(vec![1])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(1)
.project(vec![0])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![(0, 0, aggr_expr.clone())],
distinct_aggrs: vec![],
}),
}),
mfp: MapFilterProject::new(2)
.map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
.unwrap()
.project(vec![2, 3])
.unwrap(),
},
};
assert_eq!(flow_plan, expected);
}
#[tokio::test]
async fn test_sum_add() {
let engine = create_test_query_engine();
let sql = "SELECT sum(number+number) FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(Plan::Reduce {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(1)
.project(vec![])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)
.call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
.unwrap()
.project(vec![1])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![(0, 0, aggr_expr.clone())],
distinct_aggrs: vec![],
}),
}),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)])
.unwrap()
.project(vec![1])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
}

View File

@@ -0,0 +1,449 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![warn(unused_imports)]
use datatypes::data_type::ConcreteDataType as CDT;
use itertools::Itertools;
use snafu::{OptionExt, ResultExt};
use substrait::substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
use substrait::substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
use substrait::substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction};
use substrait::substrait_proto::proto::function_argument::ArgType;
use substrait::substrait_proto::proto::Expression;
use crate::adapter::error::{
DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu,
};
use crate::expr::{
BinaryFunc, ScalarExpr, TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
};
use crate::repr::{ColumnType, RelationType};
use crate::transform::literal::{from_substrait_literal, from_substrait_type};
use crate::transform::FunctionExtensions;
impl TypedExpr {
/// Convert ScalarFunction into Flow's ScalarExpr
pub fn from_substrait_scalar_func(
f: &ScalarFunction,
input_schema: &RelationType,
extensions: &FunctionExtensions,
) -> Result<TypedExpr, Error> {
let fn_name =
extensions
.get(&f.function_reference)
.with_context(|| NotImplementedSnafu {
reason: format!(
"Aggregated function not found: function reference = {:?}",
f.function_reference
),
})?;
let arg_len = f.arguments.len();
let arg_exprs: Vec<TypedExpr> = f
.arguments
.iter()
.map(|arg| match &arg.arg_type {
Some(ArgType::Value(e)) => {
TypedExpr::from_substrait_rex(e, input_schema, extensions)
}
_ => not_impl_err!("Aggregated function argument non-Value type not supported"),
})
.try_collect()?;
// literal's type is determined by the function and type of other args
let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_exprs
.into_iter()
.map(
|TypedExpr {
expr: arg_val,
typ: arg_type,
}| {
if arg_val.is_literal() {
(arg_val, None)
} else {
(arg_val, Some(arg_type.scalar_type))
}
},
)
.unzip();
match arg_len {
// because variadic function can also have 1 arguments, we need to check if it's a variadic function first
1 if VariadicFunc::from_str_and_types(fn_name, &arg_types).is_err() => {
let func = UnaryFunc::from_str_and_type(fn_name, None)?;
let arg = arg_exprs[0].clone();
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
Ok(TypedExpr::new(arg.call_unary(func), ret_type))
}
// because variadic function can also have 2 arguments, we need to check if it's a variadic function first
2 if VariadicFunc::from_str_and_types(fn_name, &arg_types).is_err() => {
let (func, signature) =
BinaryFunc::from_str_expr_and_type(fn_name, &arg_exprs, &arg_types[0..2])?;
// constant folding here
let is_all_literal = arg_exprs.iter().all(|arg| arg.is_literal());
if is_all_literal {
let res = func
.eval(&[], &arg_exprs[0], &arg_exprs[1])
.context(EvalSnafu)?;
// if output type is null, it should be inferred from the input types
let con_typ = signature.output.clone();
let typ = ColumnType::new_nullable(con_typ.clone());
return Ok(TypedExpr::new(ScalarExpr::Literal(res, con_typ), typ));
}
let mut arg_exprs = arg_exprs;
for (idx, arg_expr) in arg_exprs.iter_mut().enumerate() {
if let ScalarExpr::Literal(val, typ) = arg_expr {
let dest_type = signature.input[idx].clone();
// cast val to target_type
let dest_val = if !dest_type.is_null() {
datatypes::types::cast(val.clone(), &dest_type)
.with_context(|_|
DatatypesSnafu{
extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}")
})?
} else {
val.clone()
};
*val = dest_val;
*typ = dest_type;
}
}
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
let ret_expr = arg_exprs[0].clone().call_binary(arg_exprs[1].clone(), func);
Ok(TypedExpr::new(ret_expr, ret_type))
}
_var => {
if let Ok(func) = VariadicFunc::from_str_and_types(fn_name, &arg_types) {
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
let mut expr = ScalarExpr::CallVariadic {
func,
exprs: arg_exprs,
};
expr.optimize();
Ok(TypedExpr::new(expr, ret_type))
} else if let Ok(func) = UnmaterializableFunc::from_str(fn_name) {
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
Ok(TypedExpr::new(
ScalarExpr::CallUnmaterializable(func),
ret_type,
))
} else {
not_impl_err!("Unsupported function {fn_name} with {arg_len} arguments")
}
}
}
}
/// Convert IfThen into Flow's ScalarExpr
pub fn from_substrait_ifthen_rex(
if_then: &IfThen,
input_schema: &RelationType,
extensions: &FunctionExtensions,
) -> Result<TypedExpr, Error> {
let ifs: Vec<_> = if_then
.ifs
.iter()
.map(|if_clause| {
let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
reason: "IfThen clause without if",
})?;
let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
reason: "IfThen clause without then",
})?;
let cond = TypedExpr::from_substrait_rex(proto_if, input_schema, extensions)?;
let then = TypedExpr::from_substrait_rex(proto_then, input_schema, extensions)?;
Ok((cond, then))
})
.try_collect()?;
// if no else is presented
let els = if_then
.r#else
.as_ref()
.map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
.transpose()?
.unwrap_or_else(|| {
TypedExpr::new(
ScalarExpr::literal_null(),
ColumnType::new_nullable(CDT::null_datatype()),
)
});
fn build_if_then_recur(
mut next_if_then: impl Iterator<Item = (TypedExpr, TypedExpr)>,
els: TypedExpr,
) -> TypedExpr {
if let Some((cond, then)) = next_if_then.next() {
// always assume the type of `if`` expr is the same with the `then`` expr
TypedExpr::new(
ScalarExpr::If {
cond: Box::new(cond.expr),
then: Box::new(then.expr),
els: Box::new(build_if_then_recur(next_if_then, els).expr),
},
then.typ,
)
} else {
els
}
}
let expr_if = build_if_then_recur(ifs.into_iter(), els);
Ok(expr_if)
}
/// Convert Substrait Rex into Flow's ScalarExpr
pub fn from_substrait_rex(
e: &Expression,
input_schema: &RelationType,
extensions: &FunctionExtensions,
) -> Result<TypedExpr, Error> {
match &e.rex_type {
Some(RexType::Literal(lit)) => {
let lit = from_substrait_literal(lit)?;
Ok(TypedExpr::new(
ScalarExpr::Literal(lit.0, lit.1.clone()),
ColumnType::new_nullable(lit.1),
))
}
Some(RexType::SingularOrList(s)) => {
let substrait_expr = s.value.as_ref().with_context(|| InvalidQuerySnafu {
reason: "SingularOrList expression without value",
})?;
// Note that we didn't impl support to in list expr
if !s.options.is_empty() {
return not_impl_err!("In list expression is not supported");
}
TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions)
}
Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
Some(StructField(x)) => match &x.child.as_ref() {
Some(_) => {
not_impl_err!(
"Direct reference StructField with child is not supported"
)
}
None => {
let column = x.field as usize;
let column_type = input_schema.column_types[column].clone();
Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
}
},
_ => not_impl_err!(
"Direct reference with types other than StructField is not supported"
),
},
_ => not_impl_err!("unsupported field ref type"),
},
Some(RexType::ScalarFunction(f)) => {
TypedExpr::from_substrait_scalar_func(f, input_schema, extensions)
}
Some(RexType::IfThen(if_then)) => {
TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions)
}
Some(RexType::Cast(cast)) => {
let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
reason: "Cast expression without input",
})?;
let input = TypedExpr::from_substrait_rex(input, input_schema, extensions)?;
let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
InvalidQuerySnafu {
reason: "Cast expression without type",
}
})?)?;
let func = UnaryFunc::from_str_and_type("cast", Some(cast_type.clone()))?;
Ok(TypedExpr::new(
input.expr.call_unary(func),
ColumnType::new_nullable(cast_type),
))
}
Some(RexType::WindowFunction(_)) => PlanSnafu {
reason:
"Window function is not supported yet. Please use aggregation function instead."
.to_string(),
}
.fail(),
_ => not_impl_err!("unsupported rex_type"),
}
}
}
#[cfg(test)]
mod test {
use datatypes::value::Value;
use super::*;
use crate::expr::{GlobalId, MapFilterProject};
use crate::plan::{Plan, TypedPlan};
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
/// test if `WHERE` condition can be converted to Flow's ScalarExpr in mfp's filter
#[tokio::test]
async fn test_where_and() {
let engine = create_test_query_engine();
let sql = "SELECT number FROM numbers WHERE number >= 1 AND number <= 3 AND number!=2";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
// optimize binary and to variadic and
let filter = ScalarExpr::CallVariadic {
func: VariadicFunc::And,
exprs: vec![
ScalarExpr::Column(0).call_binary(
ScalarExpr::Literal(Value::from(1u32), CDT::uint32_datatype()),
BinaryFunc::Gte,
),
ScalarExpr::Column(0).call_binary(
ScalarExpr::Literal(Value::from(3u32), CDT::uint32_datatype()),
BinaryFunc::Lte,
),
ScalarExpr::Column(0).call_binary(
ScalarExpr::Literal(Value::from(2u32), CDT::uint32_datatype()),
BinaryFunc::NotEq,
),
],
};
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]),
plan: Plan::Mfp {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)])
.unwrap()
.filter(vec![filter])
.unwrap()
.project(vec![1])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
/// case: binary functions&constant folding can happen in converting substrait plan
#[tokio::test]
async fn test_binary_func_and_constant_folding() {
let engine = create_test_query_engine();
let sql = "SELECT 1+1*2-1/1+1%2==3 FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)]),
plan: Plan::Constant {
rows: vec![(
repr::Row::new(vec![Value::from(true)]),
repr::Timestamp::MIN,
1,
)],
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
/// test if the type of the literal is correctly inferred, i.e. in here literal is decoded to be int64, but need to be uint32,
#[tokio::test]
async fn test_implicitly_cast() {
let engine = create_test_query_engine();
let sql = "SELECT number+1 FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0).call_binary(
ScalarExpr::Literal(Value::from(1u32), CDT::uint32_datatype()),
BinaryFunc::AddUInt32,
)])
.unwrap()
.project(vec![1])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
#[tokio::test]
async fn test_cast() {
let engine = create_test_query_engine();
let sql = "SELECT CAST(1 AS INT16) FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Literal(
Value::Int64(1),
CDT::int64_datatype(),
)
.call_unary(UnaryFunc::Cast(CDT::int16_datatype()))])
.unwrap()
.project(vec![1])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
#[tokio::test]
async fn test_select_add() {
let engine = create_test_query_engine();
let sql = "SELECT number+number FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
plan: Plan::Mfp {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)
.call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
.unwrap()
.project(vec![1])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
}

View File

@@ -0,0 +1,191 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_decimal::Decimal128;
use common_time::{Date, Timestamp};
use datafusion_substrait::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DEFAULT_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
UNSIGNED_INTEGER_TYPE_REF,
};
use datatypes::data_type::ConcreteDataType as CDT;
use datatypes::value::Value;
use substrait::substrait_proto::proto::expression::literal::LiteralType;
use substrait::substrait_proto::proto::expression::Literal;
use substrait::substrait_proto::proto::r#type::Kind;
use crate::adapter::error::{Error, NotImplementedSnafu, PlanSnafu};
/// Convert a Substrait literal into a Value and its ConcreteDataType (So that we can know type even if the value is null)
pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<(Value, CDT), Error> {
let scalar_value = match &lit.literal_type {
Some(LiteralType::Boolean(b)) => (Value::from(*b), CDT::boolean_datatype()),
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
DEFAULT_TYPE_REF => (Value::from(*n as i8), CDT::int8_datatype()),
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u8), CDT::uint8_datatype()),
others => not_impl_err!("Unknown type variation reference {others}",)?,
},
Some(LiteralType::I16(n)) => match lit.type_variation_reference {
DEFAULT_TYPE_REF => (Value::from(*n as i16), CDT::int16_datatype()),
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u16), CDT::uint16_datatype()),
others => not_impl_err!("Unknown type variation reference {others}",)?,
},
Some(LiteralType::I32(n)) => match lit.type_variation_reference {
DEFAULT_TYPE_REF => (Value::from(*n), CDT::int32_datatype()),
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u32), CDT::uint32_datatype()),
others => not_impl_err!("Unknown type variation reference {others}",)?,
},
Some(LiteralType::I64(n)) => match lit.type_variation_reference {
DEFAULT_TYPE_REF => (Value::from(*n), CDT::int64_datatype()),
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u64), CDT::uint64_datatype()),
others => not_impl_err!("Unknown type variation reference {others}",)?,
},
Some(LiteralType::Fp32(f)) => (Value::from(*f), CDT::float32_datatype()),
Some(LiteralType::Fp64(f)) => (Value::from(*f), CDT::float64_datatype()),
Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference {
TIMESTAMP_SECOND_TYPE_REF => (
Value::from(Timestamp::new_second(*t)),
CDT::timestamp_second_datatype(),
),
TIMESTAMP_MILLI_TYPE_REF => (
Value::from(Timestamp::new_millisecond(*t)),
CDT::timestamp_millisecond_datatype(),
),
TIMESTAMP_MICRO_TYPE_REF => (
Value::from(Timestamp::new_microsecond(*t)),
CDT::timestamp_microsecond_datatype(),
),
TIMESTAMP_NANO_TYPE_REF => (
Value::from(Timestamp::new_nanosecond(*t)),
CDT::timestamp_nanosecond_datatype(),
),
others => not_impl_err!("Unknown type variation reference {others}",)?,
},
Some(LiteralType::Date(d)) => (Value::from(Date::new(*d)), CDT::date_datatype()),
Some(LiteralType::String(s)) => (Value::from(s.clone()), CDT::string_datatype()),
Some(LiteralType::Binary(b)) | Some(LiteralType::FixedBinary(b)) => {
(Value::from(b.clone()), CDT::binary_datatype())
}
Some(LiteralType::Decimal(d)) => {
let value: [u8; 16] = d.value.clone().try_into().map_err(|e| {
PlanSnafu {
reason: format!("Failed to parse decimal value from {e:?}"),
}
.build()
})?;
let p: u8 = d.precision.try_into().map_err(|e| {
PlanSnafu {
reason: format!("Failed to parse decimal precision: {e}"),
}
.build()
})?;
let s: i8 = d.scale.try_into().map_err(|e| {
PlanSnafu {
reason: format!("Failed to parse decimal scale: {e}"),
}
.build()
})?;
let value = i128::from_le_bytes(value);
(
Value::from(Decimal128::new(value, p, s)),
CDT::decimal128_datatype(p, s),
)
}
Some(LiteralType::Null(ntype)) => (Value::Null, from_substrait_type(ntype)?),
_ => not_impl_err!("unsupported literal_type")?,
};
Ok(scalar_value)
}
/// convert a Substrait type into a ConcreteDataType
pub fn from_substrait_type(
null_type: &substrait::substrait_proto::proto::Type,
) -> Result<CDT, Error> {
if let Some(kind) = &null_type.kind {
match kind {
Kind::Bool(_) => Ok(CDT::boolean_datatype()),
Kind::I8(integer) => match integer.type_variation_reference {
DEFAULT_TYPE_REF => Ok(CDT::int8_datatype()),
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint8_datatype()),
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
},
Kind::I16(integer) => match integer.type_variation_reference {
DEFAULT_TYPE_REF => Ok(CDT::int16_datatype()),
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint16_datatype()),
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
},
Kind::I32(integer) => match integer.type_variation_reference {
DEFAULT_TYPE_REF => Ok(CDT::int32_datatype()),
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint32_datatype()),
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
},
Kind::I64(integer) => match integer.type_variation_reference {
DEFAULT_TYPE_REF => Ok(CDT::int64_datatype()),
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint64_datatype()),
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
},
Kind::Fp32(_) => Ok(CDT::float32_datatype()),
Kind::Fp64(_) => Ok(CDT::float64_datatype()),
Kind::Timestamp(ts) => match ts.type_variation_reference {
TIMESTAMP_SECOND_TYPE_REF => Ok(CDT::timestamp_second_datatype()),
TIMESTAMP_MILLI_TYPE_REF => Ok(CDT::timestamp_millisecond_datatype()),
TIMESTAMP_MICRO_TYPE_REF => Ok(CDT::timestamp_microsecond_datatype()),
TIMESTAMP_NANO_TYPE_REF => Ok(CDT::timestamp_nanosecond_datatype()),
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
},
Kind::Date(date) => match date.type_variation_reference {
DATE_32_TYPE_REF => Ok(CDT::date_datatype()),
DATE_64_TYPE_REF => Ok(CDT::date_datatype()),
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
},
Kind::Binary(_) => Ok(CDT::binary_datatype()),
Kind::String(_) => Ok(CDT::string_datatype()),
Kind::Decimal(d) => Ok(CDT::decimal128_datatype(d.precision as u8, d.scale as i8)),
_ => not_impl_err!("Unsupported Substrait type: {kind:?}"),
}
} else {
not_impl_err!("Null type without kind is not supported")
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::plan::{Plan, TypedPlan};
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
/// test if literal in substrait plan can be correctly converted to flow plan
#[tokio::test]
async fn test_literal() {
let engine = create_test_query_engine();
let sql = "SELECT 1 FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)]),
plan: Plan::Constant {
rows: vec![(
repr::Row::new(vec![Value::Int64(1)]),
repr::Timestamp::MIN,
1,
)],
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
}

View File

@@ -0,0 +1,190 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use itertools::Itertools;
use snafu::OptionExt;
use substrait::substrait_proto::proto::expression::MaskExpression;
use substrait::substrait_proto::proto::read_rel::ReadType;
use substrait::substrait_proto::proto::rel::RelType;
use substrait::substrait_proto::proto::{plan_rel, Plan as SubPlan, Rel};
use crate::adapter::error::{Error, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu};
use crate::expr::{MapFilterProject, TypedExpr};
use crate::plan::{Plan, TypedPlan};
use crate::repr::{self, RelationType};
use crate::transform::{DataflowContext, FunctionExtensions};
impl TypedPlan {
/// Convert Substrait Plan into Flow's TypedPlan
pub fn from_substrait_plan(
ctx: &mut DataflowContext,
plan: &SubPlan,
) -> Result<TypedPlan, Error> {
// Register function extension
let function_extension = FunctionExtensions::try_from_proto(&plan.extensions)?;
// Parse relations
match plan.relations.len() {
1 => {
match plan.relations[0].rel_type.as_ref() {
Some(rt) => match rt {
plan_rel::RelType::Rel(rel) => {
Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension)?)
},
plan_rel::RelType::Root(root) => {
let input = root.input.as_ref().with_context(|| InvalidQuerySnafu {
reason: "Root relation without input",
})?;
Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension)?)
}
},
None => plan_err!("Cannot parse plan relation: None")
}
},
_ => not_impl_err!(
"Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}",
plan.relations.len()
)
}
}
/// Convert Substrait Rel into Flow's TypedPlan
/// TODO: SELECT DISTINCT(does it get compile with something else?)
pub fn from_substrait_rel(
ctx: &mut DataflowContext,
rel: &Rel,
extensions: &FunctionExtensions,
) -> Result<TypedPlan, Error> {
match &rel.rel_type {
Some(RelType::Project(p)) => {
let input = if let Some(input) = p.input.as_ref() {
TypedPlan::from_substrait_rel(ctx, input, extensions)?
} else {
return not_impl_err!("Projection without an input is not supported");
};
let mut exprs: Vec<TypedExpr> = vec![];
for e in &p.expressions {
let expr = TypedExpr::from_substrait_rex(e, &input.typ, extensions)?;
exprs.push(expr);
}
let is_literal = exprs.iter().all(|expr| expr.expr.is_literal());
if is_literal {
let (literals, lit_types): (Vec<_>, Vec<_>) = exprs
.into_iter()
.map(|TypedExpr { expr, typ }| (expr, typ))
.unzip();
let typ = RelationType::new(lit_types);
let row = literals
.into_iter()
.map(|lit| lit.as_literal().expect("A literal"))
.collect_vec();
let row = repr::Row::new(row);
let plan = Plan::Constant {
rows: vec![(row, repr::Timestamp::MIN, 1)],
};
Ok(TypedPlan { typ, plan })
} else {
input.projection(exprs)
}
}
Some(RelType::Filter(filter)) => {
let input = if let Some(input) = filter.input.as_ref() {
TypedPlan::from_substrait_rel(ctx, input, extensions)?
} else {
return not_impl_err!("Filter without an input is not supported");
};
let expr = if let Some(condition) = filter.condition.as_ref() {
TypedExpr::from_substrait_rex(condition, &input.typ, extensions)?
} else {
return not_impl_err!("Filter without an condition is not valid");
};
input.filter(expr)
}
Some(RelType::Read(read)) => {
if let Some(ReadType::NamedTable(nt)) = &read.as_ref().read_type {
let table_reference = nt.names.clone();
let table = ctx.table(&table_reference)?;
let get_table = Plan::Get {
id: crate::expr::Id::Global(table.0),
};
let get_table = TypedPlan {
typ: table.1,
plan: get_table,
};
if let Some(MaskExpression {
select: Some(projection),
..
}) = &read.projection
{
let column_indices: Vec<usize> = projection
.struct_items
.iter()
.map(|item| item.field as usize)
.collect();
let input_arity = get_table.typ.column_types.len();
let mfp =
MapFilterProject::new(input_arity).project(column_indices.clone())?;
get_table.mfp(mfp)
} else {
Ok(get_table)
}
} else {
not_impl_err!("Only NamedTable reads are supported")
}
}
Some(RelType::Aggregate(agg)) => {
TypedPlan::from_substrait_agg_rel(ctx, agg, extensions)
}
_ => not_impl_err!("Unsupported relation type: {:?}", rel.rel_type),
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::expr::{GlobalId, ScalarExpr};
use crate::plan::{Plan, TypedPlan};
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
use crate::transform::CDT;
#[tokio::test]
async fn test_select() {
let engine = create_test_query_engine();
let sql = "SELECT number FROM numbers";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan);
let expected = TypedPlan {
typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]),
plan: Plan::Mfp {
input: Box::new(Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(0)),
}),
mfp: MapFilterProject::new(1)
.map(vec![ScalarExpr::Column(0)])
.unwrap()
.project(vec![1])
.unwrap(),
},
};
assert_eq!(flow_plan.unwrap(), expected);
}
}