mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-03 20:02:54 +00:00
feat(flow): transform substrait SELECT&WHERE&GROUP BY to Flow Plan (#3690)
* feat: transform substrait SELECT&WHERE&GROUP BY to Flow Plan * chore: reexport from common/substrait * feat: use datafusion Aggr Func to map to Flow aggr func * chore: remove unwrap&split literal * refactor: split transform.rs into smaller files * feat: apply optimize for variadic fn * refactor: split unit test * chore: per review
This commit is contained in:
@@ -17,12 +17,12 @@
|
||||
mod df_substrait;
|
||||
pub mod error;
|
||||
pub mod extension_serializer;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Buf, Bytes};
|
||||
use datafusion::catalog::CatalogList;
|
||||
pub use substrait_proto;
|
||||
|
||||
pub use crate::df_substrait::DFLogicalSubstraitConvertor;
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ servers.workspace = true
|
||||
smallvec.workspace = true
|
||||
snafu.workspace = true
|
||||
strum.workspace = true
|
||||
substrait.workspace = true
|
||||
tokio.workspace = true
|
||||
tonic.workspace = true
|
||||
|
||||
@@ -39,5 +40,4 @@ prost.workspace = true
|
||||
query.workspace = true
|
||||
serde_json = "1.0"
|
||||
session.workspace = true
|
||||
substrait.workspace = true
|
||||
table.workspace = true
|
||||
|
||||
@@ -73,6 +73,13 @@ pub enum Error {
|
||||
extra: String,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Datafusion error: {raw:?} in context: {context}"))]
|
||||
Datafusion {
|
||||
raw: datafusion_common::DataFusionError,
|
||||
context: String,
|
||||
location: Location,
|
||||
},
|
||||
}
|
||||
|
||||
/// Result type for flow module
|
||||
@@ -81,7 +88,9 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
impl ErrorExt for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Self::Eval { .. } | &Self::JoinTask { .. } => StatusCode::Internal,
|
||||
Self::Eval { .. } | &Self::JoinTask { .. } | &Self::Datafusion { .. } => {
|
||||
StatusCode::Internal
|
||||
}
|
||||
&Self::TableAlreadyExist { .. } => StatusCode::TableAlreadyExists,
|
||||
Self::TableNotFound { .. } => StatusCode::TableNotFound,
|
||||
&Self::InvalidQuery { .. } | &Self::Plan { .. } | &Self::Datatypes { .. } => {
|
||||
|
||||
@@ -344,7 +344,7 @@ mod test {
|
||||
(Row::new(vec![2i64.into()]), 2, 1),
|
||||
(Row::new(vec![3i64.into()]), 3, 1),
|
||||
];
|
||||
let collection = ctx.render_constant(rows.clone());
|
||||
let collection = ctx.render_constant(rows);
|
||||
ctx.insert_global(GlobalId::User(1), collection);
|
||||
let input_plan = Plan::Get {
|
||||
id: expr::Id::Global(GlobalId::User(1)),
|
||||
@@ -440,7 +440,7 @@ mod test {
|
||||
(Row::new(vec![2.into()]), 2, 1),
|
||||
(Row::new(vec![3.into()]), 3, 1),
|
||||
];
|
||||
let collection = ctx.render_constant(rows.clone());
|
||||
let collection = ctx.render_constant(rows);
|
||||
ctx.insert_global(GlobalId::User(1), collection);
|
||||
let input_plan = Plan::Get {
|
||||
id: expr::Id::Global(GlobalId::User(1)),
|
||||
@@ -490,7 +490,7 @@ mod test {
|
||||
(Row::empty(), 2, 1),
|
||||
(Row::empty(), 3, 1),
|
||||
];
|
||||
let collection = ctx.render_constant(rows.clone());
|
||||
let collection = ctx.render_constant(rows);
|
||||
let collection = collection.collection.clone(ctx.df);
|
||||
let cnt = Rc::new(RefCell::new(0));
|
||||
let cnt_inner = cnt.clone();
|
||||
|
||||
@@ -27,4 +27,4 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}
|
||||
pub(crate) use id::{GlobalId, Id, LocalId};
|
||||
pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan};
|
||||
pub(crate) use relation::{AggregateExpr, AggregateFunc};
|
||||
pub(crate) use scalar::ScalarExpr;
|
||||
pub(crate) use scalar::{ScalarExpr, TypedExpr};
|
||||
|
||||
@@ -501,8 +501,8 @@ impl BinaryFunc {
|
||||
let spec_fn = Self::specialization(generic_fn, query_input_type)?;
|
||||
|
||||
let signature = Signature {
|
||||
input: smallvec![arg_type.clone(), arg_type.clone()],
|
||||
output: spec_fn.signature().output.clone(),
|
||||
input: smallvec![arg_type.clone(), arg_type],
|
||||
output: spec_fn.signature().output,
|
||||
generic_fn,
|
||||
};
|
||||
|
||||
@@ -767,7 +767,7 @@ fn test_num_ops() {
|
||||
assert_eq!(res, Value::from(30));
|
||||
let res = div::<i32>(left.clone(), right.clone()).unwrap();
|
||||
assert_eq!(res, Value::from(3));
|
||||
let res = rem::<i32>(left.clone(), right.clone()).unwrap();
|
||||
let res = rem::<i32>(left, right).unwrap();
|
||||
assert_eq!(res, Value::from(1));
|
||||
|
||||
let values = vec![Value::from(true), Value::from(false)];
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use common_time::{Date, DateTime};
|
||||
@@ -20,10 +21,10 @@ use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::value::{OrderedF32, OrderedF64, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smallvec::smallvec;
|
||||
use snafu::OptionExt;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use strum::{EnumIter, IntoEnumIterator};
|
||||
|
||||
use crate::adapter::error::{Error, InvalidQuerySnafu};
|
||||
use crate::adapter::error::{DatafusionSnafu, Error, InvalidQuerySnafu};
|
||||
use crate::expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu};
|
||||
use crate::expr::relation::accum::{Accum, Accumulator};
|
||||
use crate::expr::signature::{GenericFn, Signature};
|
||||
@@ -172,17 +173,32 @@ impl AggregateFunc {
|
||||
}
|
||||
spec
|
||||
});
|
||||
use datafusion_expr::aggregate_function::AggregateFunction as DfAggrFunc;
|
||||
let df_aggr_func = DfAggrFunc::from_str(name).or_else(|err| {
|
||||
if let datafusion_common::DataFusionError::NotImplemented(msg) = err {
|
||||
InvalidQuerySnafu {
|
||||
reason: format!("Unsupported aggregate function: {}", msg),
|
||||
}
|
||||
.fail()
|
||||
} else {
|
||||
DatafusionSnafu {
|
||||
raw: err,
|
||||
context: "Error when parsing aggregate function",
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
})?;
|
||||
|
||||
let generic_fn = match name {
|
||||
"max" => GenericFn::Max,
|
||||
"min" => GenericFn::Min,
|
||||
"sum" => GenericFn::Sum,
|
||||
"count" => GenericFn::Count,
|
||||
"any" => GenericFn::Any,
|
||||
"all" => GenericFn::All,
|
||||
let generic_fn = match df_aggr_func {
|
||||
DfAggrFunc::Max => GenericFn::Max,
|
||||
DfAggrFunc::Min => GenericFn::Min,
|
||||
DfAggrFunc::Sum => GenericFn::Sum,
|
||||
DfAggrFunc::Count => GenericFn::Count,
|
||||
DfAggrFunc::BoolOr => GenericFn::Any,
|
||||
DfAggrFunc::BoolAnd => GenericFn::All,
|
||||
_ => {
|
||||
return InvalidQuerySnafu {
|
||||
reason: format!("Unknown binary function: {}", name),
|
||||
reason: format!("Unknown aggregate function: {}", name),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
|
||||
@@ -24,6 +24,22 @@ use snafu::ensure;
|
||||
use crate::adapter::error::{Error, InvalidQuerySnafu, UnsupportedTemporalFilterSnafu};
|
||||
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
|
||||
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
|
||||
use crate::repr::ColumnType;
|
||||
|
||||
/// A scalar expression with a known type.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TypedExpr {
|
||||
/// The expression.
|
||||
pub expr: ScalarExpr,
|
||||
/// The type of the expression.
|
||||
pub typ: ColumnType,
|
||||
}
|
||||
|
||||
impl TypedExpr {
|
||||
pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self {
|
||||
Self { expr, typ }
|
||||
}
|
||||
}
|
||||
|
||||
/// A scalar expression, which can be evaluated to a value.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
@@ -64,6 +80,38 @@ pub enum ScalarExpr {
|
||||
},
|
||||
}
|
||||
|
||||
impl ScalarExpr {
|
||||
/// apply optimization to the expression, like flatten variadic function
|
||||
pub fn optimize(&mut self) {
|
||||
self.flatten_varidic_fn();
|
||||
}
|
||||
|
||||
/// Because Substrait's `And`/`Or` function is binary, but FlowPlan's
|
||||
/// `And`/`Or` function is variadic, we need to flatten the `And` function if multiple `And`/`Or` functions are nested.
|
||||
fn flatten_varidic_fn(&mut self) {
|
||||
if let ScalarExpr::CallVariadic { func, exprs } = self {
|
||||
let mut new_exprs = vec![];
|
||||
for expr in std::mem::take(exprs) {
|
||||
if let ScalarExpr::CallVariadic {
|
||||
func: inner_func,
|
||||
exprs: mut inner_exprs,
|
||||
} = expr
|
||||
{
|
||||
if *func == inner_func {
|
||||
for inner_expr in inner_exprs.iter_mut() {
|
||||
inner_expr.flatten_varidic_fn();
|
||||
}
|
||||
new_exprs.extend(inner_exprs);
|
||||
}
|
||||
} else {
|
||||
new_exprs.push(expr);
|
||||
}
|
||||
}
|
||||
*exprs = new_exprs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarExpr {
|
||||
/// Call a unary function on this expression.
|
||||
pub fn call_unary(self, func: UnaryFunc) -> Self {
|
||||
|
||||
@@ -27,4 +27,5 @@ mod compute;
|
||||
mod expr;
|
||||
mod plan;
|
||||
mod repr;
|
||||
mod transform;
|
||||
mod utils;
|
||||
|
||||
@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
|
||||
pub(crate) use self::reduce::{AccumulablePlan, KeyValPlan, ReducePlan};
|
||||
use crate::adapter::error::Error;
|
||||
use crate::expr::{
|
||||
AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr,
|
||||
AggregateExpr, EvalError, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr, TypedExpr,
|
||||
};
|
||||
use crate::plan::join::JoinPlan;
|
||||
use crate::repr::{ColumnType, DiffRow, RelationType};
|
||||
@@ -61,10 +61,13 @@ impl TypedPlan {
|
||||
}
|
||||
|
||||
/// project the plan to the given expressions
|
||||
pub fn projection(self, exprs: Vec<(ScalarExpr, ColumnType)>) -> Result<Self, Error> {
|
||||
pub fn projection(self, exprs: Vec<TypedExpr>) -> Result<Self, Error> {
|
||||
let input_arity = self.typ.column_types.len();
|
||||
let output_arity = exprs.len();
|
||||
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs.into_iter().unzip();
|
||||
let (exprs, expr_typs): (Vec<_>, Vec<_>) = exprs
|
||||
.into_iter()
|
||||
.map(|TypedExpr { expr, typ }| (expr, typ))
|
||||
.unzip();
|
||||
let mfp = MapFilterProject::new(input_arity)
|
||||
.map(exprs)?
|
||||
.project(input_arity..input_arity + output_arity)?;
|
||||
@@ -87,18 +90,19 @@ impl TypedPlan {
|
||||
}
|
||||
|
||||
/// Add a new filter to the plan, will filter out the records that do not satisfy the filter
|
||||
pub fn filter(self, filter: (ScalarExpr, ColumnType)) -> Result<Self, Error> {
|
||||
pub fn filter(self, filter: TypedExpr) -> Result<Self, Error> {
|
||||
let plan = match self.plan {
|
||||
Plan::Mfp {
|
||||
input,
|
||||
mfp: old_mfp,
|
||||
} => Plan::Mfp {
|
||||
input,
|
||||
mfp: old_mfp.filter(vec![filter.0])?,
|
||||
mfp: old_mfp.filter(vec![filter.expr])?,
|
||||
},
|
||||
_ => Plan::Mfp {
|
||||
input: Box::new(self.plan),
|
||||
mfp: MapFilterProject::new(self.typ.column_types.len()).filter(vec![filter.0])?,
|
||||
mfp: MapFilterProject::new(self.typ.column_types.len())
|
||||
.filter(vec![filter.expr])?,
|
||||
},
|
||||
};
|
||||
Ok(TypedPlan {
|
||||
|
||||
179
src/flow/src/transform.rs
Normal file
179
src/flow/src/transform.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Transform Substrait into execution plan
|
||||
use std::collections::HashMap;
|
||||
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
|
||||
use crate::adapter::error::{Error, NotImplementedSnafu, TableNotFoundSnafu};
|
||||
use crate::expr::GlobalId;
|
||||
use crate::repr::RelationType;
|
||||
/// Shorthand for a "not implemented" error.
///
/// Expands to `NotImplementedSnafu { reason }.fail()`, where `reason` is built
/// by passing all macro arguments to `format!`.
macro_rules! not_impl_err {
    ($($fmt_args:tt)*) => {
        NotImplementedSnafu {
            reason: format!($($fmt_args)*),
        }
        .fail()
    };
}
|
||||
|
||||
/// Shorthand for a plan error.
///
/// Expands to `PlanSnafu { reason }.fail()`, where `reason` is built by passing
/// all macro arguments to `format!`. `PlanSnafu` must be in scope at the call
/// site.
macro_rules! plan_err {
    ($($fmt_args:tt)*) => {
        PlanSnafu {
            reason: format!($($fmt_args)*),
        }
        .fail()
    };
}
|
||||
|
||||
mod aggr;
|
||||
mod expr;
|
||||
mod literal;
|
||||
mod plan;
|
||||
|
||||
use literal::{from_substrait_literal, from_substrait_type};
|
||||
use snafu::OptionExt;
|
||||
use substrait::substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
|
||||
use substrait::substrait_proto::proto::extensions::SimpleExtensionDeclaration;
|
||||
|
||||
/// In Substrait, a function can be define by an u32 anchor, and the anchor can be mapped to a name
|
||||
///
|
||||
/// So in substrait plan, a ref to a function can be a single u32 anchor instead of a full name in string
|
||||
pub struct FunctionExtensions {
|
||||
anchor_to_name: HashMap<u32, String>,
|
||||
}
|
||||
|
||||
impl FunctionExtensions {
|
||||
/// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration
|
||||
pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
|
||||
let mut anchor_to_name = HashMap::new();
|
||||
for e in extensions {
|
||||
match &e.mapping_type {
|
||||
Some(ext) => match ext {
|
||||
MappingType::ExtensionFunction(ext_f) => {
|
||||
anchor_to_name.insert(ext_f.function_anchor, ext_f.name.clone());
|
||||
}
|
||||
_ => not_impl_err!("Extension type not supported: {ext:?}")?,
|
||||
},
|
||||
None => not_impl_err!("Cannot parse empty extension")?,
|
||||
}
|
||||
}
|
||||
Ok(Self { anchor_to_name })
|
||||
}
|
||||
|
||||
/// Get the name of a function by it's anchor
|
||||
pub fn get(&self, anchor: &u32) -> Option<&String> {
|
||||
self.anchor_to_name.get(anchor)
|
||||
}
|
||||
}
|
||||
|
||||
/// A context that holds the information of the dataflow
|
||||
pub struct DataflowContext {
|
||||
/// `id` refer to any source table in the dataflow, and `name` is the name of the table
|
||||
/// which is a `Vec<String>` in substrait
|
||||
id_to_name: HashMap<GlobalId, Vec<String>>,
|
||||
/// see `id_to_name`
|
||||
name_to_id: HashMap<Vec<String>, GlobalId>,
|
||||
/// the schema of the table
|
||||
schema: HashMap<GlobalId, RelationType>,
|
||||
}
|
||||
|
||||
impl DataflowContext {
|
||||
/// Retrieves a GlobalId and table schema representing a table previously registered by calling the [register_table] function.
|
||||
///
|
||||
/// Returns an error if no table has been registered with the provided names
|
||||
pub fn table(&self, name: &Vec<String>) -> Result<(GlobalId, RelationType), Error> {
|
||||
let id = self
|
||||
.name_to_id
|
||||
.get(name)
|
||||
.copied()
|
||||
.with_context(|| TableNotFoundSnafu {
|
||||
name: name.join("."),
|
||||
})?;
|
||||
let schema = self
|
||||
.schema
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.with_context(|| TableNotFoundSnafu {
|
||||
name: name.join("."),
|
||||
})?;
|
||||
Ok((id, schema))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use prost::Message;
    use query::parser::QueryLanguageParser;
    use query::plan::LogicalPlan;
    use query::QueryEngine;
    use session::context::QueryContext;
    use substrait::substrait_proto::proto;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
    use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};

    use super::*;
    use crate::repr::ColumnType;

    /// Build a context that knows one table, `numbers`, registered as
    /// `GlobalId::User(0)` with a single `uint32` column.
    pub fn create_test_ctx() -> DataflowContext {
        let gid = GlobalId::User(0);
        let name = vec!["numbers".to_string()];
        let schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]);

        let mut ctx = DataflowContext {
            id_to_name: HashMap::new(),
            name_to_id: HashMap::new(),
            schema: HashMap::new(),
        };
        ctx.id_to_name.insert(gid, name.clone());
        ctx.name_to_id.insert(name, gid);
        ctx.schema.insert(gid, schema);
        ctx
    }

    /// Build a query engine over an in-memory catalog with the numbers table
    /// registered in the default catalog and schema.
    pub fn create_test_query_engine() -> Arc<dyn QueryEngine> {
        let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap();
        catalog_list
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name: NUMBERS_TABLE_NAME.to_string(),
                table_id: NUMBERS_TABLE_ID,
                table: NumbersTable::table(NUMBERS_TABLE_ID),
            })
            .unwrap();

        let engine =
            query::QueryEngineFactory::new(catalog_list, None, None, None, false).query_engine();
        assert_eq!("datafusion", engine.name());
        engine
    }

    /// Plan `sql` with `engine` and return the resulting substrait plan.
    pub async fn sql_to_substrait(engine: Arc<dyn QueryEngine>, sql: &str) -> proto::Plan {
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let LogicalPlan::DfPlan(df_plan) = engine
            .planner()
            .plan(stmt, QueryContext::arc())
            .await
            .unwrap();

        // Round-trip through bytes so the test relies on the real conversion
        // from logical plan to substrait plan.
        let encoded = DFLogicalSubstraitConvertor {}.encode(&df_plan).unwrap();
        proto::Plan::decode(encoded).unwrap()
    }
}
|
||||
446
src/flow/src/transform/aggr.rs
Normal file
446
src/flow/src/transform/aggr.rs
Normal file
@@ -0,0 +1,446 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common_decimal::Decimal128;
|
||||
use common_time::{Date, Timestamp};
|
||||
use datafusion_substrait::variation_const::{
|
||||
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DEFAULT_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
|
||||
TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
|
||||
UNSIGNED_INTEGER_TYPE_REF,
|
||||
};
|
||||
use datatypes::arrow::compute::kernels::window;
|
||||
use datatypes::arrow::ipc::Binary;
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
use datatypes::value::Value;
|
||||
use hydroflow::futures::future::Map;
|
||||
use itertools::Itertools;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use substrait::substrait_proto::proto::aggregate_function::AggregationInvocation;
|
||||
use substrait::substrait_proto::proto::aggregate_rel::{Grouping, Measure};
|
||||
use substrait::substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
|
||||
use substrait::substrait_proto::proto::expression::literal::LiteralType;
|
||||
use substrait::substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
|
||||
use substrait::substrait_proto::proto::expression::{
|
||||
IfThen, Literal, MaskExpression, RexType, ScalarFunction,
|
||||
};
|
||||
use substrait::substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
|
||||
use substrait::substrait_proto::proto::extensions::SimpleExtensionDeclaration;
|
||||
use substrait::substrait_proto::proto::function_argument::ArgType;
|
||||
use substrait::substrait_proto::proto::r#type::Kind;
|
||||
use substrait::substrait_proto::proto::read_rel::ReadType;
|
||||
use substrait::substrait_proto::proto::rel::RelType;
|
||||
use substrait::substrait_proto::proto::{self, plan_rel, Expression, Plan as SubPlan, Rel};
|
||||
|
||||
use crate::adapter::error::{
|
||||
DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu,
|
||||
TableNotFoundSnafu,
|
||||
};
|
||||
use crate::expr::{
|
||||
AggregateExpr, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, SafeMfpPlan, ScalarExpr,
|
||||
TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
|
||||
};
|
||||
use crate::plan::{AccumulablePlan, KeyValPlan, Plan, ReducePlan, TypedPlan};
|
||||
use crate::repr::{self, ColumnType, RelationType};
|
||||
use crate::transform::{DataflowContext, FunctionExtensions};
|
||||
|
||||
impl TypedExpr {
|
||||
fn from_substrait_agg_grouping(
|
||||
ctx: &mut DataflowContext,
|
||||
groupings: &[Grouping],
|
||||
typ: &RelationType,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<Vec<TypedExpr>, Error> {
|
||||
let _ = ctx;
|
||||
let mut group_expr = vec![];
|
||||
match groupings.len() {
|
||||
1 => {
|
||||
for e in &groupings[0].grouping_expressions {
|
||||
let x = TypedExpr::from_substrait_rex(e, typ, extensions)?;
|
||||
group_expr.push(x);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return not_impl_err!(
|
||||
"Grouping sets not support yet, use union all with group by instead."
|
||||
);
|
||||
}
|
||||
};
|
||||
Ok(group_expr)
|
||||
}
|
||||
}
|
||||
|
||||
impl AggregateExpr {
|
||||
fn from_substrait_agg_measures(
|
||||
ctx: &mut DataflowContext,
|
||||
measures: &[Measure],
|
||||
typ: &RelationType,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<Vec<AggregateExpr>, Error> {
|
||||
let _ = ctx;
|
||||
let mut aggr_exprs = vec![];
|
||||
|
||||
for m in measures {
|
||||
let filter = &m
|
||||
.filter
|
||||
.as_ref()
|
||||
.map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions))
|
||||
.transpose()?;
|
||||
|
||||
let agg_func = match &m.measure {
|
||||
Some(f) => {
|
||||
let distinct = match f.invocation {
|
||||
_ if f.invocation == AggregationInvocation::Distinct as i32 => true,
|
||||
_ if f.invocation == AggregationInvocation::All as i32 => false,
|
||||
_ => false,
|
||||
};
|
||||
AggregateExpr::from_substrait_agg_func(
|
||||
f, typ, extensions, filter, // TODO(discord9): impl order_by
|
||||
&None, distinct,
|
||||
)
|
||||
}
|
||||
None => not_impl_err!("Aggregate without aggregate function is not supported"),
|
||||
}?;
|
||||
aggr_exprs.push(agg_func);
|
||||
}
|
||||
Ok(aggr_exprs)
|
||||
}
|
||||
|
||||
/// Convert AggregateFunction into Flow's AggregateExpr
|
||||
pub fn from_substrait_agg_func(
|
||||
f: &proto::AggregateFunction,
|
||||
input_schema: &RelationType,
|
||||
extensions: &FunctionExtensions,
|
||||
filter: &Option<TypedExpr>,
|
||||
order_by: &Option<Vec<TypedExpr>>,
|
||||
distinct: bool,
|
||||
) -> Result<AggregateExpr, Error> {
|
||||
// TODO(discord9): impl filter
|
||||
let _ = filter;
|
||||
let _ = order_by;
|
||||
let mut args = vec![];
|
||||
for arg in &f.arguments {
|
||||
let arg_expr = match &arg.arg_type {
|
||||
Some(ArgType::Value(e)) => {
|
||||
TypedExpr::from_substrait_rex(e, input_schema, extensions)
|
||||
}
|
||||
_ => not_impl_err!("Aggregated function argument non-Value type not supported"),
|
||||
}?;
|
||||
args.push(arg_expr);
|
||||
}
|
||||
|
||||
let arg = if let Some(first) = args.first() {
|
||||
first
|
||||
} else {
|
||||
return not_impl_err!("Aggregated function without arguments is not supported");
|
||||
};
|
||||
|
||||
let func = match extensions.get(&f.function_reference) {
|
||||
Some(function_name) => {
|
||||
AggregateFunc::from_str_and_type(function_name, Some(arg.typ.scalar_type.clone()))
|
||||
}
|
||||
None => not_impl_err!(
|
||||
"Aggregated function not found: function anchor = {:?}",
|
||||
f.function_reference
|
||||
),
|
||||
}?;
|
||||
Ok(AggregateExpr {
|
||||
func,
|
||||
expr: arg.expr.clone(),
|
||||
distinct,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl KeyValPlan {
|
||||
/// Generate KeyValPlan from AggregateExpr and group_exprs
|
||||
///
|
||||
/// will also change aggregate expr to use column ref if necessary
|
||||
fn from_substrait_gen_key_val_plan(
|
||||
aggr_exprs: &mut [AggregateExpr],
|
||||
group_exprs: &[TypedExpr],
|
||||
input_arity: usize,
|
||||
) -> Result<KeyValPlan, Error> {
|
||||
let group_expr_val = group_exprs
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|expr| expr.expr.clone())
|
||||
.collect_vec();
|
||||
let output_arity = group_expr_val.len();
|
||||
let key_plan = MapFilterProject::new(input_arity)
|
||||
.map(group_expr_val)?
|
||||
.project(input_arity..input_arity + output_arity)?;
|
||||
|
||||
// val_plan is extracted from aggr_exprs to give aggr function it's necessary input
|
||||
// and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref
|
||||
let val_plan = {
|
||||
let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
|
||||
if need_mfp {
|
||||
// create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp
|
||||
let input_exprs = aggr_exprs
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(idx, aggr)| {
|
||||
let ret = aggr.expr.clone();
|
||||
aggr.expr = ScalarExpr::Column(idx);
|
||||
ret
|
||||
})
|
||||
.collect_vec();
|
||||
let aggr_arity = aggr_exprs.len();
|
||||
|
||||
MapFilterProject::new(input_arity)
|
||||
.map(input_exprs)?
|
||||
.project(input_arity..input_arity + aggr_arity)?
|
||||
} else {
|
||||
// simply take all inputs as value
|
||||
MapFilterProject::new(input_arity)
|
||||
}
|
||||
};
|
||||
Ok(KeyValPlan {
|
||||
key_plan: key_plan.into_safe(),
|
||||
val_plan: val_plan.into_safe(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TypedPlan {
|
||||
/// Convert AggregateRel into Flow's TypedPlan
|
||||
pub fn from_substrait_agg_rel(
|
||||
ctx: &mut DataflowContext,
|
||||
agg: &proto::AggregateRel,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
let input = if let Some(input) = agg.input.as_ref() {
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions)?
|
||||
} else {
|
||||
return not_impl_err!("Aggregate without an input is not supported");
|
||||
};
|
||||
|
||||
let group_expr =
|
||||
TypedExpr::from_substrait_agg_grouping(ctx, &agg.groupings, &input.typ, extensions)?;
|
||||
|
||||
let mut aggr_exprs =
|
||||
AggregateExpr::from_substrait_agg_measures(ctx, &agg.measures, &input.typ, extensions)?;
|
||||
|
||||
let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan(
|
||||
&mut aggr_exprs,
|
||||
&group_expr,
|
||||
input.typ.column_types.len(),
|
||||
)?;
|
||||
|
||||
let output_type = {
|
||||
let mut output_types = Vec::new();
|
||||
// first append group_expr as key, then aggr_expr as value
|
||||
for expr in &group_expr {
|
||||
output_types.push(expr.typ.clone());
|
||||
}
|
||||
|
||||
for aggr in &aggr_exprs {
|
||||
output_types.push(ColumnType::new_nullable(
|
||||
aggr.func.signature().output.clone(),
|
||||
));
|
||||
}
|
||||
RelationType::new(output_types)
|
||||
};
|
||||
|
||||
// copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
|
||||
// also set them input/output column
|
||||
let full_aggrs = aggr_exprs;
|
||||
let mut simple_aggrs = Vec::new();
|
||||
let mut distinct_aggrs = Vec::new();
|
||||
for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
|
||||
let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
|
||||
reason: "Expect aggregate argument to be transformed into a column at this point",
|
||||
})?;
|
||||
if aggr_expr.distinct {
|
||||
distinct_aggrs.push((output_column, input_column, aggr_expr.clone()));
|
||||
} else {
|
||||
simple_aggrs.push((output_column, input_column, aggr_expr.clone()));
|
||||
}
|
||||
}
|
||||
let accum_plan = AccumulablePlan {
|
||||
full_aggrs,
|
||||
simple_aggrs,
|
||||
distinct_aggrs,
|
||||
};
|
||||
let plan = Plan::Reduce {
|
||||
input: Box::new(input.plan),
|
||||
key_val_plan,
|
||||
reduce_plan: ReducePlan::Accumulable(accum_plan),
|
||||
};
|
||||
Ok(TypedPlan {
|
||||
typ: output_type,
|
||||
plan,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// The aggregate every test expects: a non-distinct `SumUInt32` reading
    /// column 0.
    fn sum_of_first_column() -> AggregateExpr {
        AggregateExpr {
            func: AggregateFunc::SumUInt32,
            expr: ScalarExpr::Column(0),
            distinct: false,
        }
    }

    /// Expected `Plan::Reduce` over table `GlobalId::User(0)` with the given
    /// key/value plans and the single aggregate of `sum_of_first_column`.
    fn expected_reduce(key_plan: MapFilterProject, val_plan: MapFilterProject) -> Plan {
        let aggr = sum_of_first_column();
        Plan::Reduce {
            input: Box::new(Plan::Get {
                id: crate::expr::Id::Global(GlobalId::User(0)),
            }),
            key_val_plan: KeyValPlan {
                key_plan: key_plan.into_safe(),
                val_plan: val_plan.into_safe(),
            },
            reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
                full_aggrs: vec![aggr.clone()],
                simple_aggrs: vec![(0, 0, aggr)],
                distinct_aggrs: vec![],
            }),
        }
    }

    /// Plan `sql` through the test query engine and convert the resulting
    /// substrait plan into a flow plan, panicking on any failure.
    async fn flow_plan_of(sql: &str) -> TypedPlan {
        let engine = create_test_query_engine();
        let plan = sql_to_substrait(engine, sql).await;

        let mut ctx = create_test_ctx();
        TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap()
    }

    #[tokio::test]
    async fn test_sum() {
        let flow_plan = flow_plan_of("SELECT sum(number) FROM numbers").await;

        let expected = TypedPlan {
            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
            plan: Plan::Mfp {
                input: Box::new(expected_reduce(
                    MapFilterProject::new(1).project(vec![]).unwrap(),
                    MapFilterProject::new(1).project(vec![0]).unwrap(),
                )),
                mfp: MapFilterProject::new(1)
                    .map(vec![ScalarExpr::Column(0)])
                    .unwrap()
                    .project(vec![1])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }

    #[tokio::test]
    async fn test_sum_group_by() {
        let flow_plan =
            flow_plan_of("SELECT sum(number), number FROM numbers GROUP BY number").await;

        let expected = TypedPlan {
            typ: RelationType::new(vec![
                ColumnType::new(CDT::uint32_datatype(), true),
                ColumnType::new(CDT::uint32_datatype(), false),
            ]),
            plan: Plan::Mfp {
                input: Box::new(expected_reduce(
                    MapFilterProject::new(1)
                        .map(vec![ScalarExpr::Column(0)])
                        .unwrap()
                        .project(vec![1])
                        .unwrap(),
                    MapFilterProject::new(1).project(vec![0]).unwrap(),
                )),
                // The reduce emits (key, sum); the query selects (sum, key).
                mfp: MapFilterProject::new(2)
                    .map(vec![ScalarExpr::Column(1), ScalarExpr::Column(0)])
                    .unwrap()
                    .project(vec![2, 3])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }

    #[tokio::test]
    async fn test_sum_add() {
        let flow_plan = flow_plan_of("SELECT sum(number+number) FROM numbers").await;

        let expected = TypedPlan {
            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
            plan: Plan::Mfp {
                input: Box::new(expected_reduce(
                    MapFilterProject::new(1).project(vec![]).unwrap(),
                    // The non-column aggregate argument is computed by the
                    // value plan.
                    MapFilterProject::new(1)
                        .map(vec![ScalarExpr::Column(0)
                            .call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32)])
                        .unwrap()
                        .project(vec![1])
                        .unwrap(),
                )),
                mfp: MapFilterProject::new(1)
                    .map(vec![ScalarExpr::Column(0)])
                    .unwrap()
                    .project(vec![1])
                    .unwrap(),
            },
        };
        assert_eq!(flow_plan, expected);
    }
}
|
||||
449
src/flow/src/transform/expr.rs
Normal file
449
src/flow/src/transform/expr.rs
Normal file
@@ -0,0 +1,449 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![warn(unused_imports)]
|
||||
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
use itertools::Itertools;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use substrait::substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference;
|
||||
use substrait::substrait_proto::proto::expression::reference_segment::ReferenceType::StructField;
|
||||
use substrait::substrait_proto::proto::expression::{IfThen, RexType, ScalarFunction};
|
||||
use substrait::substrait_proto::proto::function_argument::ArgType;
|
||||
use substrait::substrait_proto::proto::Expression;
|
||||
|
||||
use crate::adapter::error::{
|
||||
DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu,
|
||||
};
|
||||
use crate::expr::{
|
||||
BinaryFunc, ScalarExpr, TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
|
||||
};
|
||||
use crate::repr::{ColumnType, RelationType};
|
||||
use crate::transform::literal::{from_substrait_literal, from_substrait_type};
|
||||
use crate::transform::FunctionExtensions;
|
||||
|
||||
impl TypedExpr {
|
||||
/// Convert ScalarFunction into Flow's ScalarExpr
|
||||
pub fn from_substrait_scalar_func(
|
||||
f: &ScalarFunction,
|
||||
input_schema: &RelationType,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
let fn_name =
|
||||
extensions
|
||||
.get(&f.function_reference)
|
||||
.with_context(|| NotImplementedSnafu {
|
||||
reason: format!(
|
||||
"Aggregated function not found: function reference = {:?}",
|
||||
f.function_reference
|
||||
),
|
||||
})?;
|
||||
let arg_len = f.arguments.len();
|
||||
let arg_exprs: Vec<TypedExpr> = f
|
||||
.arguments
|
||||
.iter()
|
||||
.map(|arg| match &arg.arg_type {
|
||||
Some(ArgType::Value(e)) => {
|
||||
TypedExpr::from_substrait_rex(e, input_schema, extensions)
|
||||
}
|
||||
_ => not_impl_err!("Aggregated function argument non-Value type not supported"),
|
||||
})
|
||||
.try_collect()?;
|
||||
|
||||
// literal's type is determined by the function and type of other args
|
||||
let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_exprs
|
||||
.into_iter()
|
||||
.map(
|
||||
|TypedExpr {
|
||||
expr: arg_val,
|
||||
typ: arg_type,
|
||||
}| {
|
||||
if arg_val.is_literal() {
|
||||
(arg_val, None)
|
||||
} else {
|
||||
(arg_val, Some(arg_type.scalar_type))
|
||||
}
|
||||
},
|
||||
)
|
||||
.unzip();
|
||||
|
||||
match arg_len {
|
||||
// because variadic function can also have 1 arguments, we need to check if it's a variadic function first
|
||||
1 if VariadicFunc::from_str_and_types(fn_name, &arg_types).is_err() => {
|
||||
let func = UnaryFunc::from_str_and_type(fn_name, None)?;
|
||||
let arg = arg_exprs[0].clone();
|
||||
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
|
||||
|
||||
Ok(TypedExpr::new(arg.call_unary(func), ret_type))
|
||||
}
|
||||
// because variadic function can also have 2 arguments, we need to check if it's a variadic function first
|
||||
2 if VariadicFunc::from_str_and_types(fn_name, &arg_types).is_err() => {
|
||||
let (func, signature) =
|
||||
BinaryFunc::from_str_expr_and_type(fn_name, &arg_exprs, &arg_types[0..2])?;
|
||||
|
||||
// constant folding here
|
||||
let is_all_literal = arg_exprs.iter().all(|arg| arg.is_literal());
|
||||
if is_all_literal {
|
||||
let res = func
|
||||
.eval(&[], &arg_exprs[0], &arg_exprs[1])
|
||||
.context(EvalSnafu)?;
|
||||
|
||||
// if output type is null, it should be inferred from the input types
|
||||
let con_typ = signature.output.clone();
|
||||
let typ = ColumnType::new_nullable(con_typ.clone());
|
||||
return Ok(TypedExpr::new(ScalarExpr::Literal(res, con_typ), typ));
|
||||
}
|
||||
|
||||
let mut arg_exprs = arg_exprs;
|
||||
for (idx, arg_expr) in arg_exprs.iter_mut().enumerate() {
|
||||
if let ScalarExpr::Literal(val, typ) = arg_expr {
|
||||
let dest_type = signature.input[idx].clone();
|
||||
|
||||
// cast val to target_type
|
||||
let dest_val = if !dest_type.is_null() {
|
||||
datatypes::types::cast(val.clone(), &dest_type)
|
||||
.with_context(|_|
|
||||
DatatypesSnafu{
|
||||
extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}")
|
||||
})?
|
||||
} else {
|
||||
val.clone()
|
||||
};
|
||||
*val = dest_val;
|
||||
*typ = dest_type;
|
||||
}
|
||||
}
|
||||
|
||||
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
|
||||
let ret_expr = arg_exprs[0].clone().call_binary(arg_exprs[1].clone(), func);
|
||||
Ok(TypedExpr::new(ret_expr, ret_type))
|
||||
}
|
||||
_var => {
|
||||
if let Ok(func) = VariadicFunc::from_str_and_types(fn_name, &arg_types) {
|
||||
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
|
||||
let mut expr = ScalarExpr::CallVariadic {
|
||||
func,
|
||||
exprs: arg_exprs,
|
||||
};
|
||||
expr.optimize();
|
||||
Ok(TypedExpr::new(expr, ret_type))
|
||||
} else if let Ok(func) = UnmaterializableFunc::from_str(fn_name) {
|
||||
let ret_type = ColumnType::new_nullable(func.signature().output.clone());
|
||||
Ok(TypedExpr::new(
|
||||
ScalarExpr::CallUnmaterializable(func),
|
||||
ret_type,
|
||||
))
|
||||
} else {
|
||||
not_impl_err!("Unsupported function {fn_name} with {arg_len} arguments")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert IfThen into Flow's ScalarExpr
|
||||
pub fn from_substrait_ifthen_rex(
|
||||
if_then: &IfThen,
|
||||
input_schema: &RelationType,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
let ifs: Vec<_> = if_then
|
||||
.ifs
|
||||
.iter()
|
||||
.map(|if_clause| {
|
||||
let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "IfThen clause without if",
|
||||
})?;
|
||||
let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "IfThen clause without then",
|
||||
})?;
|
||||
let cond = TypedExpr::from_substrait_rex(proto_if, input_schema, extensions)?;
|
||||
let then = TypedExpr::from_substrait_rex(proto_then, input_schema, extensions)?;
|
||||
Ok((cond, then))
|
||||
})
|
||||
.try_collect()?;
|
||||
// if no else is presented
|
||||
let els = if_then
|
||||
.r#else
|
||||
.as_ref()
|
||||
.map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions))
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| {
|
||||
TypedExpr::new(
|
||||
ScalarExpr::literal_null(),
|
||||
ColumnType::new_nullable(CDT::null_datatype()),
|
||||
)
|
||||
});
|
||||
|
||||
fn build_if_then_recur(
|
||||
mut next_if_then: impl Iterator<Item = (TypedExpr, TypedExpr)>,
|
||||
els: TypedExpr,
|
||||
) -> TypedExpr {
|
||||
if let Some((cond, then)) = next_if_then.next() {
|
||||
// always assume the type of `if`` expr is the same with the `then`` expr
|
||||
TypedExpr::new(
|
||||
ScalarExpr::If {
|
||||
cond: Box::new(cond.expr),
|
||||
then: Box::new(then.expr),
|
||||
els: Box::new(build_if_then_recur(next_if_then, els).expr),
|
||||
},
|
||||
then.typ,
|
||||
)
|
||||
} else {
|
||||
els
|
||||
}
|
||||
}
|
||||
let expr_if = build_if_then_recur(ifs.into_iter(), els);
|
||||
Ok(expr_if)
|
||||
}
|
||||
/// Convert Substrait Rex into Flow's ScalarExpr
|
||||
pub fn from_substrait_rex(
|
||||
e: &Expression,
|
||||
input_schema: &RelationType,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
match &e.rex_type {
|
||||
Some(RexType::Literal(lit)) => {
|
||||
let lit = from_substrait_literal(lit)?;
|
||||
Ok(TypedExpr::new(
|
||||
ScalarExpr::Literal(lit.0, lit.1.clone()),
|
||||
ColumnType::new_nullable(lit.1),
|
||||
))
|
||||
}
|
||||
Some(RexType::SingularOrList(s)) => {
|
||||
let substrait_expr = s.value.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "SingularOrList expression without value",
|
||||
})?;
|
||||
// Note that we didn't impl support to in list expr
|
||||
if !s.options.is_empty() {
|
||||
return not_impl_err!("In list expression is not supported");
|
||||
}
|
||||
TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions)
|
||||
}
|
||||
Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
|
||||
Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
|
||||
Some(StructField(x)) => match &x.child.as_ref() {
|
||||
Some(_) => {
|
||||
not_impl_err!(
|
||||
"Direct reference StructField with child is not supported"
|
||||
)
|
||||
}
|
||||
None => {
|
||||
let column = x.field as usize;
|
||||
let column_type = input_schema.column_types[column].clone();
|
||||
Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
|
||||
}
|
||||
},
|
||||
_ => not_impl_err!(
|
||||
"Direct reference with types other than StructField is not supported"
|
||||
),
|
||||
},
|
||||
_ => not_impl_err!("unsupported field ref type"),
|
||||
},
|
||||
Some(RexType::ScalarFunction(f)) => {
|
||||
TypedExpr::from_substrait_scalar_func(f, input_schema, extensions)
|
||||
}
|
||||
Some(RexType::IfThen(if_then)) => {
|
||||
TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions)
|
||||
}
|
||||
Some(RexType::Cast(cast)) => {
|
||||
let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "Cast expression without input",
|
||||
})?;
|
||||
let input = TypedExpr::from_substrait_rex(input, input_schema, extensions)?;
|
||||
let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| {
|
||||
InvalidQuerySnafu {
|
||||
reason: "Cast expression without type",
|
||||
}
|
||||
})?)?;
|
||||
let func = UnaryFunc::from_str_and_type("cast", Some(cast_type.clone()))?;
|
||||
Ok(TypedExpr::new(
|
||||
input.expr.call_unary(func),
|
||||
ColumnType::new_nullable(cast_type),
|
||||
))
|
||||
}
|
||||
Some(RexType::WindowFunction(_)) => PlanSnafu {
|
||||
reason:
|
||||
"Window function is not supported yet. Please use aggregation function instead."
|
||||
.to_string(),
|
||||
}
|
||||
.fail(),
|
||||
_ => not_impl_err!("unsupported rex_type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use datatypes::value::Value;

    use super::*;
    use crate::expr::{GlobalId, MapFilterProject};
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// Plans `sql` with the test query engine and converts the resulting
    /// Substrait plan into a Flow plan using a fresh test context.
    async fn to_flow_plan(sql: &str) -> Result<TypedPlan, Error> {
        let engine = create_test_query_engine();
        let substrait_plan = sql_to_substrait(engine.clone(), sql).await;

        let mut ctx = create_test_ctx();
        TypedPlan::from_substrait_plan(&mut ctx, &substrait_plan)
    }

    /// The `Get` node every expected plan in this module reads from.
    fn source_table() -> Box<Plan> {
        Box::new(Plan::Get {
            id: crate::expr::Id::Global(GlobalId::User(0)),
        })
    }

    /// A `uint32` literal expression.
    fn uint32_literal(value: u32) -> ScalarExpr {
        ScalarExpr::Literal(Value::from(value), CDT::uint32_datatype())
    }

    /// test if `WHERE` condition can be converted to Flow's ScalarExpr in mfp's filter
    #[tokio::test]
    async fn test_where_and() {
        let flow_plan = to_flow_plan(
            "SELECT number FROM numbers WHERE number >= 1 AND number <= 3 AND number!=2",
        )
        .await;

        // the chain of binary `AND`s is expected to be optimized into one variadic `And`
        let at_least_one = ScalarExpr::Column(0).call_binary(uint32_literal(1), BinaryFunc::Gte);
        let at_most_three = ScalarExpr::Column(0).call_binary(uint32_literal(3), BinaryFunc::Lte);
        let not_two = ScalarExpr::Column(0).call_binary(uint32_literal(2), BinaryFunc::NotEq);
        let predicate = ScalarExpr::CallVariadic {
            func: VariadicFunc::And,
            exprs: vec![at_least_one, at_most_three, not_two],
        };

        let mfp = MapFilterProject::new(1)
            .map(vec![ScalarExpr::Column(0)])
            .unwrap()
            .filter(vec![predicate])
            .unwrap()
            .project(vec![1])
            .unwrap();
        let expected = TypedPlan {
            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]),
            plan: Plan::Mfp {
                input: source_table(),
                mfp,
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// case: binary functions&constant folding can happen in converting substrait plan
    #[tokio::test]
    async fn test_binary_func_and_constant_folding() {
        let flow_plan = to_flow_plan("SELECT 1+1*2-1/1+1%2==3 FROM numbers").await;

        // the whole expression is expected to fold into one constant boolean row
        let folded_row = (
            repr::Row::new(vec![Value::from(true)]),
            repr::Timestamp::MIN,
            1,
        );
        let expected = TypedPlan {
            typ: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)]),
            plan: Plan::Constant {
                rows: vec![folded_row],
            },
        };

        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// test if the type of the literal is correctly inferred, i.e. in here literal is decoded to be int64, but need to be uint32,
    #[tokio::test]
    async fn test_implicitly_cast() {
        let flow_plan = to_flow_plan("SELECT number+1 FROM numbers").await;

        let number_plus_one =
            ScalarExpr::Column(0).call_binary(uint32_literal(1), BinaryFunc::AddUInt32);
        let mfp = MapFilterProject::new(1)
            .map(vec![number_plus_one])
            .unwrap()
            .project(vec![1])
            .unwrap();
        let expected = TypedPlan {
            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
            plan: Plan::Mfp {
                input: source_table(),
                mfp,
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// An explicit `CAST` is expected to become a `UnaryFunc::Cast` call around the int64 literal.
    #[tokio::test]
    async fn test_cast() {
        let flow_plan = to_flow_plan("SELECT CAST(1 AS INT16) FROM numbers").await;

        let casted_literal = ScalarExpr::Literal(Value::Int64(1), CDT::int64_datatype())
            .call_unary(UnaryFunc::Cast(CDT::int16_datatype()));
        let mfp = MapFilterProject::new(1)
            .map(vec![casted_literal])
            .unwrap()
            .project(vec![1])
            .unwrap();
        let expected = TypedPlan {
            typ: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)]),
            plan: Plan::Mfp {
                input: source_table(),
                mfp,
            },
        };
        assert_eq!(flow_plan.unwrap(), expected);
    }

    /// Adding a column to itself is expected to resolve to the `AddUInt32` binary function.
    #[tokio::test]
    async fn test_select_add() {
        let flow_plan = to_flow_plan("SELECT number+number FROM numbers").await;

        let number_doubled =
            ScalarExpr::Column(0).call_binary(ScalarExpr::Column(0), BinaryFunc::AddUInt32);
        let mfp = MapFilterProject::new(1)
            .map(vec![number_doubled])
            .unwrap()
            .project(vec![1])
            .unwrap();
        let expected = TypedPlan {
            typ: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]),
            plan: Plan::Mfp {
                input: source_table(),
                mfp,
            },
        };

        assert_eq!(flow_plan.unwrap(), expected);
    }
}
|
||||
191
src/flow/src/transform/literal.rs
Normal file
191
src/flow/src/transform/literal.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_decimal::Decimal128;
|
||||
use common_time::{Date, Timestamp};
|
||||
use datafusion_substrait::variation_const::{
|
||||
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DEFAULT_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
|
||||
TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
|
||||
UNSIGNED_INTEGER_TYPE_REF,
|
||||
};
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
use datatypes::value::Value;
|
||||
use substrait::substrait_proto::proto::expression::literal::LiteralType;
|
||||
use substrait::substrait_proto::proto::expression::Literal;
|
||||
use substrait::substrait_proto::proto::r#type::Kind;
|
||||
|
||||
use crate::adapter::error::{Error, NotImplementedSnafu, PlanSnafu};
|
||||
|
||||
/// Convert a Substrait literal into a Value and its ConcreteDataType (So that we can know type even if the value is null)
|
||||
pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<(Value, CDT), Error> {
|
||||
let scalar_value = match &lit.literal_type {
|
||||
Some(LiteralType::Boolean(b)) => (Value::from(*b), CDT::boolean_datatype()),
|
||||
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => (Value::from(*n as i8), CDT::int8_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u8), CDT::uint8_datatype()),
|
||||
others => not_impl_err!("Unknown type variation reference {others}",)?,
|
||||
},
|
||||
Some(LiteralType::I16(n)) => match lit.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => (Value::from(*n as i16), CDT::int16_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u16), CDT::uint16_datatype()),
|
||||
others => not_impl_err!("Unknown type variation reference {others}",)?,
|
||||
},
|
||||
Some(LiteralType::I32(n)) => match lit.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => (Value::from(*n), CDT::int32_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u32), CDT::uint32_datatype()),
|
||||
others => not_impl_err!("Unknown type variation reference {others}",)?,
|
||||
},
|
||||
Some(LiteralType::I64(n)) => match lit.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => (Value::from(*n), CDT::int64_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => (Value::from(*n as u64), CDT::uint64_datatype()),
|
||||
others => not_impl_err!("Unknown type variation reference {others}",)?,
|
||||
},
|
||||
Some(LiteralType::Fp32(f)) => (Value::from(*f), CDT::float32_datatype()),
|
||||
Some(LiteralType::Fp64(f)) => (Value::from(*f), CDT::float64_datatype()),
|
||||
Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference {
|
||||
TIMESTAMP_SECOND_TYPE_REF => (
|
||||
Value::from(Timestamp::new_second(*t)),
|
||||
CDT::timestamp_second_datatype(),
|
||||
),
|
||||
TIMESTAMP_MILLI_TYPE_REF => (
|
||||
Value::from(Timestamp::new_millisecond(*t)),
|
||||
CDT::timestamp_millisecond_datatype(),
|
||||
),
|
||||
TIMESTAMP_MICRO_TYPE_REF => (
|
||||
Value::from(Timestamp::new_microsecond(*t)),
|
||||
CDT::timestamp_microsecond_datatype(),
|
||||
),
|
||||
TIMESTAMP_NANO_TYPE_REF => (
|
||||
Value::from(Timestamp::new_nanosecond(*t)),
|
||||
CDT::timestamp_nanosecond_datatype(),
|
||||
),
|
||||
others => not_impl_err!("Unknown type variation reference {others}",)?,
|
||||
},
|
||||
Some(LiteralType::Date(d)) => (Value::from(Date::new(*d)), CDT::date_datatype()),
|
||||
Some(LiteralType::String(s)) => (Value::from(s.clone()), CDT::string_datatype()),
|
||||
Some(LiteralType::Binary(b)) | Some(LiteralType::FixedBinary(b)) => {
|
||||
(Value::from(b.clone()), CDT::binary_datatype())
|
||||
}
|
||||
Some(LiteralType::Decimal(d)) => {
|
||||
let value: [u8; 16] = d.value.clone().try_into().map_err(|e| {
|
||||
PlanSnafu {
|
||||
reason: format!("Failed to parse decimal value from {e:?}"),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let p: u8 = d.precision.try_into().map_err(|e| {
|
||||
PlanSnafu {
|
||||
reason: format!("Failed to parse decimal precision: {e}"),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let s: i8 = d.scale.try_into().map_err(|e| {
|
||||
PlanSnafu {
|
||||
reason: format!("Failed to parse decimal scale: {e}"),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let value = i128::from_le_bytes(value);
|
||||
(
|
||||
Value::from(Decimal128::new(value, p, s)),
|
||||
CDT::decimal128_datatype(p, s),
|
||||
)
|
||||
}
|
||||
Some(LiteralType::Null(ntype)) => (Value::Null, from_substrait_type(ntype)?),
|
||||
_ => not_impl_err!("unsupported literal_type")?,
|
||||
};
|
||||
Ok(scalar_value)
|
||||
}
|
||||
|
||||
/// convert a Substrait type into a ConcreteDataType
|
||||
pub fn from_substrait_type(
|
||||
null_type: &substrait::substrait_proto::proto::Type,
|
||||
) -> Result<CDT, Error> {
|
||||
if let Some(kind) = &null_type.kind {
|
||||
match kind {
|
||||
Kind::Bool(_) => Ok(CDT::boolean_datatype()),
|
||||
Kind::I8(integer) => match integer.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => Ok(CDT::int8_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint8_datatype()),
|
||||
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
|
||||
},
|
||||
Kind::I16(integer) => match integer.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => Ok(CDT::int16_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint16_datatype()),
|
||||
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
|
||||
},
|
||||
Kind::I32(integer) => match integer.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => Ok(CDT::int32_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint32_datatype()),
|
||||
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
|
||||
},
|
||||
Kind::I64(integer) => match integer.type_variation_reference {
|
||||
DEFAULT_TYPE_REF => Ok(CDT::int64_datatype()),
|
||||
UNSIGNED_INTEGER_TYPE_REF => Ok(CDT::uint64_datatype()),
|
||||
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
|
||||
},
|
||||
Kind::Fp32(_) => Ok(CDT::float32_datatype()),
|
||||
Kind::Fp64(_) => Ok(CDT::float64_datatype()),
|
||||
Kind::Timestamp(ts) => match ts.type_variation_reference {
|
||||
TIMESTAMP_SECOND_TYPE_REF => Ok(CDT::timestamp_second_datatype()),
|
||||
TIMESTAMP_MILLI_TYPE_REF => Ok(CDT::timestamp_millisecond_datatype()),
|
||||
TIMESTAMP_MICRO_TYPE_REF => Ok(CDT::timestamp_microsecond_datatype()),
|
||||
TIMESTAMP_NANO_TYPE_REF => Ok(CDT::timestamp_nanosecond_datatype()),
|
||||
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
|
||||
},
|
||||
Kind::Date(date) => match date.type_variation_reference {
|
||||
DATE_32_TYPE_REF => Ok(CDT::date_datatype()),
|
||||
DATE_64_TYPE_REF => Ok(CDT::date_datatype()),
|
||||
v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
|
||||
},
|
||||
Kind::Binary(_) => Ok(CDT::binary_datatype()),
|
||||
Kind::String(_) => Ok(CDT::string_datatype()),
|
||||
Kind::Decimal(d) => Ok(CDT::decimal128_datatype(d.precision as u8, d.scale as i8)),
|
||||
_ => not_impl_err!("Unsupported Substrait type: {kind:?}"),
|
||||
}
|
||||
} else {
|
||||
not_impl_err!("Null type without kind is not supported")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// test if literal in substrait plan can be correctly converted to flow plan
    #[tokio::test]
    async fn test_literal() {
        let query = "SELECT 1 FROM numbers";
        let engine = create_test_query_engine();
        let substrait_plan = sql_to_substrait(engine.clone(), query).await;

        let mut ctx = create_test_ctx();
        let actual = TypedPlan::from_substrait_plan(&mut ctx, &substrait_plan);

        // A projection of only literals is expected to become one constant row.
        let constant_row = (
            repr::Row::new(vec![Value::Int64(1)]),
            repr::Timestamp::MIN,
            1,
        );
        let expected_type = RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)]);
        let expected = TypedPlan {
            typ: expected_type,
            plan: Plan::Constant {
                rows: vec![constant_row],
            },
        };

        assert_eq!(actual.unwrap(), expected);
    }
}
|
||||
190
src/flow/src/transform/plan.rs
Normal file
190
src/flow/src/transform/plan.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use itertools::Itertools;
|
||||
use snafu::OptionExt;
|
||||
use substrait::substrait_proto::proto::expression::MaskExpression;
|
||||
use substrait::substrait_proto::proto::read_rel::ReadType;
|
||||
use substrait::substrait_proto::proto::rel::RelType;
|
||||
use substrait::substrait_proto::proto::{plan_rel, Plan as SubPlan, Rel};
|
||||
|
||||
use crate::adapter::error::{Error, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu};
|
||||
use crate::expr::{MapFilterProject, TypedExpr};
|
||||
use crate::plan::{Plan, TypedPlan};
|
||||
use crate::repr::{self, RelationType};
|
||||
use crate::transform::{DataflowContext, FunctionExtensions};
|
||||
|
||||
impl TypedPlan {
|
||||
/// Convert Substrait Plan into Flow's TypedPlan
|
||||
pub fn from_substrait_plan(
|
||||
ctx: &mut DataflowContext,
|
||||
plan: &SubPlan,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
// Register function extension
|
||||
let function_extension = FunctionExtensions::try_from_proto(&plan.extensions)?;
|
||||
|
||||
// Parse relations
|
||||
match plan.relations.len() {
|
||||
1 => {
|
||||
match plan.relations[0].rel_type.as_ref() {
|
||||
Some(rt) => match rt {
|
||||
plan_rel::RelType::Rel(rel) => {
|
||||
Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension)?)
|
||||
},
|
||||
plan_rel::RelType::Root(root) => {
|
||||
let input = root.input.as_ref().with_context(|| InvalidQuerySnafu {
|
||||
reason: "Root relation without input",
|
||||
})?;
|
||||
Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension)?)
|
||||
}
|
||||
},
|
||||
None => plan_err!("Cannot parse plan relation: None")
|
||||
}
|
||||
},
|
||||
_ => not_impl_err!(
|
||||
"Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}",
|
||||
plan.relations.len()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Substrait Rel into Flow's TypedPlan
|
||||
/// TODO: SELECT DISTINCT(does it get compile with something else?)
|
||||
pub fn from_substrait_rel(
|
||||
ctx: &mut DataflowContext,
|
||||
rel: &Rel,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedPlan, Error> {
|
||||
match &rel.rel_type {
|
||||
Some(RelType::Project(p)) => {
|
||||
let input = if let Some(input) = p.input.as_ref() {
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions)?
|
||||
} else {
|
||||
return not_impl_err!("Projection without an input is not supported");
|
||||
};
|
||||
let mut exprs: Vec<TypedExpr> = vec![];
|
||||
for e in &p.expressions {
|
||||
let expr = TypedExpr::from_substrait_rex(e, &input.typ, extensions)?;
|
||||
exprs.push(expr);
|
||||
}
|
||||
let is_literal = exprs.iter().all(|expr| expr.expr.is_literal());
|
||||
if is_literal {
|
||||
let (literals, lit_types): (Vec<_>, Vec<_>) = exprs
|
||||
.into_iter()
|
||||
.map(|TypedExpr { expr, typ }| (expr, typ))
|
||||
.unzip();
|
||||
let typ = RelationType::new(lit_types);
|
||||
let row = literals
|
||||
.into_iter()
|
||||
.map(|lit| lit.as_literal().expect("A literal"))
|
||||
.collect_vec();
|
||||
let row = repr::Row::new(row);
|
||||
let plan = Plan::Constant {
|
||||
rows: vec![(row, repr::Timestamp::MIN, 1)],
|
||||
};
|
||||
Ok(TypedPlan { typ, plan })
|
||||
} else {
|
||||
input.projection(exprs)
|
||||
}
|
||||
}
|
||||
Some(RelType::Filter(filter)) => {
|
||||
let input = if let Some(input) = filter.input.as_ref() {
|
||||
TypedPlan::from_substrait_rel(ctx, input, extensions)?
|
||||
} else {
|
||||
return not_impl_err!("Filter without an input is not supported");
|
||||
};
|
||||
|
||||
let expr = if let Some(condition) = filter.condition.as_ref() {
|
||||
TypedExpr::from_substrait_rex(condition, &input.typ, extensions)?
|
||||
} else {
|
||||
return not_impl_err!("Filter without an condition is not valid");
|
||||
};
|
||||
input.filter(expr)
|
||||
}
|
||||
Some(RelType::Read(read)) => {
|
||||
if let Some(ReadType::NamedTable(nt)) = &read.as_ref().read_type {
|
||||
let table_reference = nt.names.clone();
|
||||
let table = ctx.table(&table_reference)?;
|
||||
let get_table = Plan::Get {
|
||||
id: crate::expr::Id::Global(table.0),
|
||||
};
|
||||
let get_table = TypedPlan {
|
||||
typ: table.1,
|
||||
plan: get_table,
|
||||
};
|
||||
|
||||
if let Some(MaskExpression {
|
||||
select: Some(projection),
|
||||
..
|
||||
}) = &read.projection
|
||||
{
|
||||
let column_indices: Vec<usize> = projection
|
||||
.struct_items
|
||||
.iter()
|
||||
.map(|item| item.field as usize)
|
||||
.collect();
|
||||
let input_arity = get_table.typ.column_types.len();
|
||||
let mfp =
|
||||
MapFilterProject::new(input_arity).project(column_indices.clone())?;
|
||||
get_table.mfp(mfp)
|
||||
} else {
|
||||
Ok(get_table)
|
||||
}
|
||||
} else {
|
||||
not_impl_err!("Only NamedTable reads are supported")
|
||||
}
|
||||
}
|
||||
Some(RelType::Aggregate(agg)) => {
|
||||
TypedPlan::from_substrait_agg_rel(ctx, agg, extensions)
|
||||
}
|
||||
_ => not_impl_err!("Unsupported relation type: {:?}", rel.rel_type),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;
    use crate::expr::{GlobalId, ScalarExpr};
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
    use crate::transform::CDT;

    /// Checks that a plain single-column `SELECT` over the test `numbers` table
    /// is lowered to an `Mfp` plan sitting directly on top of a `Get` of the
    /// table's global id.
    #[tokio::test]
    async fn test_select() {
        // SQL -> substrait plan, using the shared test query engine.
        let query_engine = create_test_query_engine();
        let query = "SELECT number FROM numbers";
        let substrait_plan = sql_to_substrait(query_engine.clone(), query).await;

        // substrait plan -> flow plan.
        let mut flow_ctx = create_test_ctx();
        let actual = TypedPlan::from_substrait_plan(&mut flow_ctx, &substrait_plan);

        // Expected source: the table registered in the test context as user id 0.
        let source = Plan::Get {
            id: crate::expr::Id::Global(GlobalId::User(0)),
        };
        // Expected operator: map input column 0 into a new column, then keep
        // only that mapped column (index 1).
        let select_number = MapFilterProject::new(1)
            .map(vec![ScalarExpr::Column(0)])
            .unwrap()
            .project(vec![1])
            .unwrap();
        // Expected output schema: a single `uint32` column.
        let column_types = vec![ColumnType::new(CDT::uint32_datatype(), false)];

        let expected = TypedPlan {
            typ: RelationType::new(column_types),
            plan: Plan::Mfp {
                input: Box::new(source),
                mfp: select_number,
            },
        };

        assert_eq!(actual.unwrap(), expected);
    }
}
|
||||
Reference in New Issue
Block a user