From d82a3a7d58462923d78a7aae91883f55e8848df9 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 2 Dec 2022 14:46:05 +0800 Subject: [PATCH] feat: implement most of scalar function and selection conversion in substrait (#678) * impl to_df_scalar_function Signed-off-by: Ruihang Xia * part of scalar functions Signed-off-by: Ruihang Xia * conjunction over filters Signed-off-by: Ruihang Xia * change the ser/de target to substrait::Plan Signed-off-by: Ruihang Xia * basic test coverage Signed-off-by: Ruihang Xia * fix typos and license header Signed-off-by: Ruihang Xia * fix clippy Signed-off-by: Ruihang Xia * fix CR comments Signed-off-by: Ruihang Xia * logs unsupported extension Signed-off-by: Ruihang Xia * Update src/common/substrait/src/df_expr.rs Co-authored-by: Yingwen * address review comments Signed-off-by: Ruihang Xia * change format Signed-off-by: Ruihang Xia * replace context with with_context Signed-off-by: Ruihang Xia Signed-off-by: Ruihang Xia Co-authored-by: Yingwen --- Cargo.lock | 2 + src/common/substrait/Cargo.toml | 2 + src/common/substrait/src/context.rs | 66 +++ src/common/substrait/src/df_expr.rs | 742 +++++++++++++++++++++++++ src/common/substrait/src/df_logical.rs | 126 ++++- src/common/substrait/src/error.rs | 4 +- src/common/substrait/src/lib.rs | 4 + 7 files changed, 916 insertions(+), 30 deletions(-) create mode 100644 src/common/substrait/src/context.rs create mode 100644 src/common/substrait/src/df_expr.rs diff --git a/Cargo.lock b/Cargo.lock index d3836fbf5a..ba44455cf2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6059,7 +6059,9 @@ dependencies = [ "catalog", "common-catalog", "common-error", + "common-telemetry", "datafusion", + "datafusion-expr", "datatypes", "futures", "prost 0.9.0", diff --git a/src/common/substrait/Cargo.toml b/src/common/substrait/Cargo.toml index 0f5502c732..9f9aea0b5e 100644 --- a/src/common/substrait/Cargo.toml +++ b/src/common/substrait/Cargo.toml @@ -9,9 +9,11 @@ bytes = "1.1" catalog = { path = "../../catalog" 
} common-catalog = { path = "../catalog" } common-error = { path = "../error" } +common-telemetry = { path = "../telemetry" } datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ "simd", ] } +datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" } datatypes = { path = "../../datatypes" } futures = "0.3" prost = "0.9" diff --git a/src/common/substrait/src/context.rs b/src/common/substrait/src/context.rs new file mode 100644 index 0000000000..893546ea48 --- /dev/null +++ b/src/common/substrait/src/context.rs @@ -0,0 +1,66 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::HashMap; + +use substrait_proto::protobuf::extensions::simple_extension_declaration::{ + ExtensionFunction, MappingType, +}; +use substrait_proto::protobuf::extensions::SimpleExtensionDeclaration; + +#[derive(Default)] +pub struct ConvertorContext { + scalar_fn_names: HashMap, + scalar_fn_map: HashMap, +} + +impl ConvertorContext { + pub fn register_scalar_fn>(&mut self, name: S) -> u32 { + if let Some(anchor) = self.scalar_fn_names.get(name.as_ref()) { + return *anchor; + } + + let next_anchor = self.scalar_fn_map.len() as _; + self.scalar_fn_map + .insert(next_anchor, name.as_ref().to_string()); + self.scalar_fn_names + .insert(name.as_ref().to_string(), next_anchor); + next_anchor + } + + pub fn register_scalar_with_anchor>(&mut self, name: S, anchor: u32) { + self.scalar_fn_map.insert(anchor, name.as_ref().to_string()); + self.scalar_fn_names + .insert(name.as_ref().to_string(), anchor); + } + + pub fn find_scalar_fn(&self, anchor: u32) -> Option<&str> { + self.scalar_fn_map.get(&anchor).map(|s| s.as_str()) + } + + pub fn generate_function_extension(&self) -> Vec { + let mut result = Vec::with_capacity(self.scalar_fn_map.len()); + for (anchor, name) in &self.scalar_fn_map { + let declaration = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction { + extension_uri_reference: 0, + function_anchor: *anchor, + name: name.clone(), + })), + }; + result.push(declaration); + } + result + } +} diff --git a/src/common/substrait/src/df_expr.rs b/src/common/substrait/src/df_expr.rs new file mode 100644 index 0000000000..8267fa9cc1 --- /dev/null +++ b/src/common/substrait/src/df_expr.rs @@ -0,0 +1,742 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; +use std::str::FromStr; + +use datafusion::logical_plan::{Column, Expr}; +use datafusion_expr::{expr_fn, BuiltinScalarFunction, Operator}; +use datatypes::schema::Schema; +use snafu::{ensure, OptionExt}; +use substrait_proto::protobuf::expression::field_reference::ReferenceType as FieldReferenceType; +use substrait_proto::protobuf::expression::reference_segment::{ + ReferenceType as SegReferenceType, StructField, +}; +use substrait_proto::protobuf::expression::{ + FieldReference, ReferenceSegment, RexType, ScalarFunction, +}; +use substrait_proto::protobuf::function_argument::ArgType; +use substrait_proto::protobuf::Expression; + +use crate::context::ConvertorContext; +use crate::error::{ + EmptyExprSnafu, InvalidParametersSnafu, MissingFieldSnafu, Result, UnsupportedExprSnafu, +}; + +/// Convert substrait's `Expression` to DataFusion's `Expr`. 
+pub fn to_df_expr(ctx: &ConvertorContext, expression: Expression, schema: &Schema) -> Result { + let expr_rex_type = expression.rex_type.context(EmptyExprSnafu)?; + match expr_rex_type { + RexType::Literal(_) => UnsupportedExprSnafu { + name: "substrait Literal expression", + } + .fail()?, + RexType::Selection(selection) => convert_selection_rex(*selection, schema), + RexType::ScalarFunction(scalar_fn) => convert_scalar_function(ctx, scalar_fn, schema), + RexType::WindowFunction(_) + | RexType::IfThen(_) + | RexType::SwitchExpression(_) + | RexType::SingularOrList(_) + | RexType::MultiOrList(_) + | RexType::Cast(_) + | RexType::Subquery(_) + | RexType::Enum(_) => UnsupportedExprSnafu { + name: format!("substrait expression {:?}", expr_rex_type), + } + .fail()?, + } +} + +/// Convert Substrait's `FieldReference` - `DirectReference` - `StructField` to Datafusion's +/// `Column` expr. +pub fn convert_selection_rex(selection: FieldReference, schema: &Schema) -> Result { + if let Some(FieldReferenceType::DirectReference(direct_ref)) = selection.reference_type + && let Some(SegReferenceType::StructField(field)) = direct_ref.reference_type { + let column_name = schema.column_name_by_index(field.field as _).to_string(); + Ok(Expr::Column(Column { + relation: None, + name: column_name, + })) + } else { + InvalidParametersSnafu { + reason: "Only support direct struct reference in Selection Rex", + } + .fail() + } +} + +pub fn convert_scalar_function( + ctx: &ConvertorContext, + scalar_fn: ScalarFunction, + schema: &Schema, +) -> Result { + // convert argument + let mut inputs = VecDeque::with_capacity(scalar_fn.arguments.len()); + for arg in scalar_fn.arguments { + if let Some(ArgType::Value(sub_expr)) = arg.arg_type { + inputs.push_back(to_df_expr(ctx, sub_expr, schema)?); + } else { + InvalidParametersSnafu { + reason: "Only value expression arg is supported to be function argument", + } + .fail()?; + } + } + + // convert this scalar function + // map function name + let 
anchor = scalar_fn.function_reference; + let fn_name = ctx + .find_scalar_fn(anchor) + .with_context(|| InvalidParametersSnafu { + reason: format!("Unregistered scalar function reference: {}", anchor), + })?; + + // convenient util + let ensure_arg_len = |expected: usize| -> Result<()> { + ensure!( + inputs.len() == expected, + InvalidParametersSnafu { + reason: format!( + "Invalid number of scalar function {}, expected {} but found {}", + fn_name, + expected, + inputs.len() + ) + } + ); + Ok(()) + }; + + // construct DataFusion expr + let expr = match fn_name { + // begin binary exprs, with the same order of DF `Operator`'s definition. + "eq" | "equal" => { + ensure_arg_len(2)?; + inputs.pop_front().unwrap().eq(inputs.pop_front().unwrap()) + } + "not_eq" | "not_equal" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .not_eq(inputs.pop_front().unwrap()) + } + "lt" => { + ensure_arg_len(2)?; + inputs.pop_front().unwrap().lt(inputs.pop_front().unwrap()) + } + "lt_eq" | "lte" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .lt_eq(inputs.pop_front().unwrap()) + } + "gt" => { + ensure_arg_len(2)?; + inputs.pop_front().unwrap().gt(inputs.pop_front().unwrap()) + } + "gt_eq" | "gte" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .gt_eq(inputs.pop_front().unwrap()) + } + "plus" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Plus, + inputs.pop_front().unwrap(), + ) + } + "minus" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Minus, + inputs.pop_front().unwrap(), + ) + } + "multiply" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Multiply, + inputs.pop_front().unwrap(), + ) + } + "divide" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Divide, + inputs.pop_front().unwrap(), + ) + } + "modulo" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( 
+ inputs.pop_front().unwrap(), + Operator::Modulo, + inputs.pop_front().unwrap(), + ) + } + "and" => { + ensure_arg_len(2)?; + expr_fn::and(inputs.pop_front().unwrap(), inputs.pop_front().unwrap()) + } + "or" => { + ensure_arg_len(2)?; + expr_fn::or(inputs.pop_front().unwrap(), inputs.pop_front().unwrap()) + } + "like" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .like(inputs.pop_front().unwrap()) + } + "not_like" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .not_like(inputs.pop_front().unwrap()) + } + "is_distinct_from" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::IsDistinctFrom, + inputs.pop_front().unwrap(), + ) + } + "is_not_distinct_from" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::IsNotDistinctFrom, + inputs.pop_front().unwrap(), + ) + } + "regex_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexMatch, + inputs.pop_front().unwrap(), + ) + } + "regex_i_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexIMatch, + inputs.pop_front().unwrap(), + ) + } + "regex_not_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexNotMatch, + inputs.pop_front().unwrap(), + ) + } + "regex_not_i_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexNotIMatch, + inputs.pop_front().unwrap(), + ) + } + "bitwise_and" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::BitwiseAnd, + inputs.pop_front().unwrap(), + ) + } + "bitwise_or" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::BitwiseOr, + inputs.pop_front().unwrap(), + ) + } + // end binary exprs + // start other direct expr, with the same order of DF `Expr`'s definition. 
+ "not" => { + ensure_arg_len(1)?; + inputs.pop_front().unwrap().not() + } + "is_not_null" => { + ensure_arg_len(1)?; + inputs.pop_front().unwrap().is_not_null() + } + "is_null" => { + ensure_arg_len(1)?; + inputs.pop_front().unwrap().is_null() + } + "negative" => { + ensure_arg_len(1)?; + Expr::Negative(Box::new(inputs.pop_front().unwrap())) + } + // skip GetIndexedField, unimplemented. + "between" => { + ensure_arg_len(3)?; + Expr::Between { + expr: Box::new(inputs.pop_front().unwrap()), + negated: false, + low: Box::new(inputs.pop_front().unwrap()), + high: Box::new(inputs.pop_front().unwrap()), + } + } + "not_between" => { + ensure_arg_len(3)?; + Expr::Between { + expr: Box::new(inputs.pop_front().unwrap()), + negated: true, + low: Box::new(inputs.pop_front().unwrap()), + high: Box::new(inputs.pop_front().unwrap()), + } + } + // skip Case, is covered in substrait::SwitchExpression. + // skip Cast and TryCast, is covered in substrait::Cast. + "sort" | "sort_des" => { + ensure_arg_len(1)?; + Expr::Sort { + expr: Box::new(inputs.pop_front().unwrap()), + asc: false, + nulls_first: false, + } + } + "sort_asc" => { + ensure_arg_len(1)?; + Expr::Sort { + expr: Box::new(inputs.pop_front().unwrap()), + asc: true, + nulls_first: false, + } + } + // those are datafusion built-in "scalar functions". 
+ "abs" + | "acos" + | "asin" + | "atan" + | "atan2" + | "ceil" + | "cos" + | "exp" + | "floor" + | "ln" + | "log" + | "log10" + | "log2" + | "power" + | "pow" + | "round" + | "signum" + | "sin" + | "sqrt" + | "tan" + | "trunc" + | "coalesce" + | "make_array" + | "ascii" + | "bit_length" + | "btrim" + | "char_length" + | "character_length" + | "concat" + | "concat_ws" + | "chr" + | "current_date" + | "current_time" + | "date_part" + | "datepart" + | "date_trunc" + | "datetrunc" + | "date_bin" + | "initcap" + | "left" + | "length" + | "lower" + | "lpad" + | "ltrim" + | "md5" + | "nullif" + | "octet_length" + | "random" + | "regexp_replace" + | "repeat" + | "replace" + | "reverse" + | "right" + | "rpad" + | "rtrim" + | "sha224" + | "sha256" + | "sha384" + | "sha512" + | "digest" + | "split_part" + | "starts_with" + | "strpos" + | "substr" + | "to_hex" + | "to_timestamp" + | "to_timestamp_millis" + | "to_timestamp_micros" + | "to_timestamp_seconds" + | "now" + | "translate" + | "trim" + | "upper" + | "uuid" + | "regexp_match" + | "struct" + | "from_unixtime" + | "arrow_typeof" => Expr::ScalarFunction { + fun: BuiltinScalarFunction::from_str(fn_name).unwrap(), + args: inputs.into(), + }, + // skip ScalarUDF, unimplemented. + // skip AggregateFunction, is covered in substrait::AggregateRel + // skip WindowFunction, is covered in substrait WindowFunction + // skip AggregateUDF, unimplemented. + // skip InList, unimplemented + // skip Wildcard, unimplemented. + // end other direct expr + _ => UnsupportedExprSnafu { + name: format!("scalar function {}", fn_name), + } + .fail()?, + }; + + Ok(expr) +} + +/// Convert DataFusion's `Expr` to substrait's `Expression` +pub fn expression_from_df_expr( + ctx: &mut ConvertorContext, + expr: &Expr, + schema: &Schema, +) -> Result { + let expression = match expr { + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::Alias(..) 
=> UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::Column(column) => { + let field_reference = convert_column(column, schema)?; + Expression { + rex_type: Some(RexType::Selection(Box::new(field_reference))), + } + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::ScalarVariable(..) | Expr::Literal(..) => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::BinaryExpr { left, op, right } => { + let left = expression_from_df_expr(ctx, left, schema)?; + let right = expression_from_df_expr(ctx, right, schema)?; + let arguments = utils::expression_to_argument(vec![left, right]); + let op_name = utils::name_df_operator(op); + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::Not(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "not"; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::IsNotNull(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "is_not_null"; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::IsNull(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "is_null"; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::Negative(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "negative"; + let function_reference = ctx.register_scalar_fn(op_name); + 
utils::build_scalar_function_expression(function_reference, arguments) + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::GetIndexedField { .. } => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::Between { + expr, + negated, + low, + high, + } => { + let expr = expression_from_df_expr(ctx, expr, schema)?; + let low = expression_from_df_expr(ctx, low, schema)?; + let high = expression_from_df_expr(ctx, high, schema)?; + let arguments = utils::expression_to_argument(vec![expr, low, high]); + let op_name = if *negated { "not_between" } else { "between" }; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::Sort { + expr, + asc, + nulls_first: _, + } => { + let expr = expression_from_df_expr(ctx, expr, schema)?; + let arguments = utils::expression_to_argument(vec![expr]); + let op_name = if *asc { "sort_asc" } else { "sort_des" }; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::ScalarFunction { fun, args } => { + let arguments = utils::expression_to_argument( + args.iter() + .map(|e| expression_from_df_expr(ctx, e, schema)) + .collect::>>()?, + ); + let op_name = utils::name_builtin_scalar_function(fun); + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::ScalarUDF { .. } + | Expr::AggregateFunction { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateUDF { .. } + | Expr::InList { .. 
} + | Expr::Wildcard => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + }; + + Ok(expression) +} + +/// Convert DataFusion's `Column` expr into substrait's `FieldReference` - +/// `DirectReference` - `StructField`. +pub fn convert_column(column: &Column, schema: &Schema) -> Result { + let column_name = &column.name; + let field_index = + schema + .column_index_by_name(column_name) + .with_context(|| MissingFieldSnafu { + field: format!("{:?}", column), + plan: format!("schema: {:?}", schema), + })?; + + Ok(FieldReference { + reference_type: Some(FieldReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(SegReferenceType::StructField(Box::new(StructField { + field: field_index as _, + child: None, + }))), + })), + root_type: None, + }) +} + +/// Some utils special for this `DataFusion::Expr` and `Substrait::Expression` conversion. +mod utils { + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use substrait_proto::protobuf::expression::{RexType, ScalarFunction}; + use substrait_proto::protobuf::function_argument::ArgType; + use substrait_proto::protobuf::{Expression, FunctionArgument}; + + pub(crate) fn name_df_operator(op: &Operator) -> &str { + match op { + Operator::Eq => "equal", + Operator::NotEq => "not_equal", + Operator::Lt => "lt", + Operator::LtEq => "lte", + Operator::Gt => "gt", + Operator::GtEq => "gte", + Operator::Plus => "plus", + Operator::Minus => "minus", + Operator::Multiply => "multiply", + Operator::Divide => "divide", + Operator::Modulo => "modulo", + Operator::And => "and", + Operator::Or => "or", + Operator::Like => "like", + Operator::NotLike => "not_like", + Operator::IsDistinctFrom => "is_distinct_from", + Operator::IsNotDistinctFrom => "is_not_distinct_from", + Operator::RegexMatch => "regex_match", + Operator::RegexIMatch => "regex_i_match", + Operator::RegexNotMatch => "regex_not_match", + Operator::RegexNotIMatch => "regex_not_i_match", + Operator::BitwiseAnd => "bitwise_and", + 
Operator::BitwiseOr => "bitwise_or", + } + } + + /// Convert list of [Expression] to [FunctionArgument] vector. + pub(crate) fn expression_to_argument>( + expressions: I, + ) -> Vec { + expressions + .into_iter() + .map(|expr| FunctionArgument { + arg_type: Some(ArgType::Value(expr)), + }) + .collect() + } + + /// Convenient builder for [Expression] + pub(crate) fn build_scalar_function_expression( + function_reference: u32, + arguments: Vec, + ) -> Expression { + Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference, + arguments, + output_type: None, + ..Default::default() + })), + } + } + + pub(crate) fn name_builtin_scalar_function(fun: &BuiltinScalarFunction) -> &str { + match fun { + BuiltinScalarFunction::Abs => "abs", + BuiltinScalarFunction::Acos => "acos", + BuiltinScalarFunction::Asin => "asin", + BuiltinScalarFunction::Atan => "atan", + BuiltinScalarFunction::Ceil => "ceil", + BuiltinScalarFunction::Cos => "cos", + BuiltinScalarFunction::Digest => "digest", + BuiltinScalarFunction::Exp => "exp", + BuiltinScalarFunction::Floor => "floor", + BuiltinScalarFunction::Ln => "ln", + BuiltinScalarFunction::Log => "log", + BuiltinScalarFunction::Log10 => "log10", + BuiltinScalarFunction::Log2 => "log2", + BuiltinScalarFunction::Round => "round", + BuiltinScalarFunction::Signum => "signum", + BuiltinScalarFunction::Sin => "sin", + BuiltinScalarFunction::Sqrt => "sqrt", + BuiltinScalarFunction::Tan => "tan", + BuiltinScalarFunction::Trunc => "trunc", + BuiltinScalarFunction::Array => "make_array", + BuiltinScalarFunction::Ascii => "ascii", + BuiltinScalarFunction::BitLength => "bit_length", + BuiltinScalarFunction::Btrim => "btrim", + BuiltinScalarFunction::CharacterLength => "character_length", + BuiltinScalarFunction::Chr => "chr", + BuiltinScalarFunction::Concat => "concat", + BuiltinScalarFunction::ConcatWithSeparator => "concat_ws", + BuiltinScalarFunction::DatePart => "date_part", + BuiltinScalarFunction::DateTrunc => 
"date_trunc", + BuiltinScalarFunction::InitCap => "initcap", + BuiltinScalarFunction::Left => "left", + BuiltinScalarFunction::Lpad => "lpad", + BuiltinScalarFunction::Lower => "lower", + BuiltinScalarFunction::Ltrim => "ltrim", + BuiltinScalarFunction::MD5 => "md5", + BuiltinScalarFunction::NullIf => "nullif", + BuiltinScalarFunction::OctetLength => "octet_length", + BuiltinScalarFunction::Random => "random", + BuiltinScalarFunction::RegexpReplace => "regexp_replace", + BuiltinScalarFunction::Repeat => "repeat", + BuiltinScalarFunction::Replace => "replace", + BuiltinScalarFunction::Reverse => "reverse", + BuiltinScalarFunction::Right => "right", + BuiltinScalarFunction::Rpad => "rpad", + BuiltinScalarFunction::Rtrim => "rtrim", + BuiltinScalarFunction::SHA224 => "sha224", + BuiltinScalarFunction::SHA256 => "sha256", + BuiltinScalarFunction::SHA384 => "sha384", + BuiltinScalarFunction::SHA512 => "sha512", + BuiltinScalarFunction::SplitPart => "split_part", + BuiltinScalarFunction::StartsWith => "starts_with", + BuiltinScalarFunction::Strpos => "strpos", + BuiltinScalarFunction::Substr => "substr", + BuiltinScalarFunction::ToHex => "to_hex", + BuiltinScalarFunction::ToTimestamp => "to_timestamp", + BuiltinScalarFunction::ToTimestampMillis => "to_timestamp_millis", + BuiltinScalarFunction::ToTimestampMicros => "to_timestamp_micros", + BuiltinScalarFunction::ToTimestampSeconds => "to_timestamp_seconds", + BuiltinScalarFunction::Now => "now", + BuiltinScalarFunction::Translate => "translate", + BuiltinScalarFunction::Trim => "trim", + BuiltinScalarFunction::Upper => "upper", + BuiltinScalarFunction::RegexpMatch => "regexp_match", + } + } +} + +#[cfg(test)] +mod test { + use datatypes::schema::ColumnSchema; + + use super::*; + + #[test] + fn expr_round_trip() { + let expr = expr_fn::and( + expr_fn::col("column_a").lt_eq(expr_fn::col("column_b")), + expr_fn::col("column_a").gt(expr_fn::col("column_b")), + ); + + let schema = Schema::new(vec![ + ColumnSchema::new( + 
"column_a", + datatypes::data_type::ConcreteDataType::int64_datatype(), + true, + ), + ColumnSchema::new( + "column_b", + datatypes::data_type::ConcreteDataType::float64_datatype(), + true, + ), + ]); + + let mut ctx = ConvertorContext::default(); + let substrait_expr = expression_from_df_expr(&mut ctx, &expr, &schema).unwrap(); + let converted_expr = to_df_expr(&ctx, substrait_expr, &schema).unwrap(); + + assert_eq!(expr, converted_expr); + } +} diff --git a/src/common/substrait/src/df_logical.rs b/src/common/substrait/src/df_logical.rs index 6f0573144c..8d53ef1b08 100644 --- a/src/common/substrait/src/df_logical.rs +++ b/src/common/substrait/src/df_logical.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use bytes::{Buf, Bytes, BytesMut}; use catalog::CatalogManagerRef; use common_error::prelude::BoxedError; +use common_telemetry::debug; use datafusion::datasource::TableProvider; use datafusion::logical_plan::{LogicalPlan, TableScan, ToDFSchema}; use datafusion::physical_plan::project_schema; @@ -24,12 +25,15 @@ use prost::Message; use snafu::{ensure, OptionExt, ResultExt}; use substrait_proto::protobuf::expression::mask_expression::{StructItem, StructSelect}; use substrait_proto::protobuf::expression::MaskExpression; +use substrait_proto::protobuf::extensions::simple_extension_declaration::MappingType; use substrait_proto::protobuf::plan_rel::RelType as PlanRelType; use substrait_proto::protobuf::read_rel::{NamedTable, ReadType}; use substrait_proto::protobuf::rel::RelType; -use substrait_proto::protobuf::{PlanRel, ReadRel, Rel}; +use substrait_proto::protobuf::{Plan, PlanRel, ReadRel, Rel}; use table::table::adapter::DfTableProviderAdapter; +use crate::context::ConvertorContext; +use crate::df_expr::{expression_from_df_expr, to_df_expr}; use crate::error::{ DFInternalSnafu, DecodeRelSnafu, EmptyPlanSnafu, EncodeRelSnafu, Error, InternalSnafu, InvalidParametersSnafu, MissingFieldSnafu, SchemaNotMatchSnafu, TableNotFoundSnafu, @@ -48,25 +52,15 @@ impl SubstraitPlan for 
DFLogicalSubstraitConvertor { type Plan = LogicalPlan; fn decode(&self, message: B) -> Result { - let plan_rel = PlanRel::decode(message).context(DecodeRelSnafu)?; - let rel = match plan_rel.rel_type.context(EmptyPlanSnafu)? { - PlanRelType::Rel(rel) => rel, - PlanRelType::Root(_) => UnsupportedPlanSnafu { - name: "Root Relation", - } - .fail()?, - }; - self.convert_rel(rel) + let plan = Plan::decode(message).context(DecodeRelSnafu)?; + self.convert_plan(plan) } fn encode(&self, plan: Self::Plan) -> Result { - let rel = self.convert_plan(plan)?; - let plan_rel = PlanRel { - rel_type: Some(PlanRelType::Rel(rel)), - }; + let plan = self.convert_df_plan(plan)?; let mut buf = BytesMut::new(); - plan_rel.encode(&mut buf).context(EncodeRelSnafu)?; + plan.encode(&mut buf).context(EncodeRelSnafu)?; Ok(buf.freeze()) } @@ -79,10 +73,37 @@ impl DFLogicalSubstraitConvertor { } impl DFLogicalSubstraitConvertor { - pub fn convert_rel(&self, rel: Rel) -> Result { + pub fn convert_plan(&self, mut plan: Plan) -> Result { + // prepare convertor context + let mut ctx = ConvertorContext::default(); + for simple_ext in plan.extensions { + if let Some(MappingType::ExtensionFunction(function_extension)) = + simple_ext.mapping_type + { + ctx.register_scalar_with_anchor( + function_extension.name, + function_extension.function_anchor, + ); + } else { + debug!("Encounter unsupported substrait extension {:?}", simple_ext); + } + } + + // extract rel + let rel = if let Some(PlanRel { rel_type }) = plan.relations.pop() + && let Some(PlanRelType::Rel(rel)) = rel_type { + rel + } else { + UnsupportedPlanSnafu { + name: "Empty or non-Rel relation", + } + .fail()? 
+ }; let rel_type = rel.rel_type.context(EmptyPlanSnafu)?; + + // build logical plan let logical_plan = match rel_type { - RelType::Read(read_rel) => self.convert_read_rel(read_rel), + RelType::Read(read_rel) => self.convert_read_rel(&mut ctx, read_rel), RelType::Filter(_filter_rel) => UnsupportedPlanSnafu { name: "Filter Relation", } @@ -132,9 +153,12 @@ impl DFLogicalSubstraitConvertor { Ok(logical_plan) } - fn convert_read_rel(&self, read_rel: Box) -> Result { + fn convert_read_rel( + &self, + ctx: &mut ConvertorContext, + read_rel: Box, + ) -> Result { // Extract the catalog, schema and table name from NamedTable. Assume the first three are those names. - let read_type = read_rel.read_type.context(MissingFieldSnafu { field: "read_type", plan: "Read", @@ -190,6 +214,13 @@ impl DFLogicalSubstraitConvertor { } ); + // Convert filter + let filters = if let Some(filter) = read_rel.filter { + vec![to_df_expr(ctx, *filter, &retrieved_schema)?] + } else { + vec![] + }; + // Calculate the projected schema let projected_schema = project_schema(&stored_schema, projection.as_ref()) .context(DFInternalSnafu)? 
@@ -202,7 +233,7 @@ impl DFLogicalSubstraitConvertor { source: adapter, projection, projected_schema, - filters: vec![], + filters, limit: None, })) } @@ -219,8 +250,12 @@ impl DFLogicalSubstraitConvertor { } impl DFLogicalSubstraitConvertor { - pub fn convert_plan(&self, plan: LogicalPlan) -> Result { - match plan { + pub fn convert_df_plan(&self, plan: LogicalPlan) -> Result { + let mut ctx = ConvertorContext::default(); + + // TODO(ruihang): extract this translation logic into a separated function + // convert PlanRel + let rel = match plan { LogicalPlan::Projection(_) => UnsupportedPlanSnafu { name: "DataFusion Logical Projection", } @@ -258,10 +293,10 @@ impl DFLogicalSubstraitConvertor { } .fail()?, LogicalPlan::TableScan(table_scan) => { - let read_rel = self.convert_table_scan_plan(table_scan)?; - Ok(Rel { + let read_rel = self.convert_table_scan_plan(&mut ctx, table_scan)?; + Rel { rel_type: Some(RelType::Read(Box::new(read_rel))), - }) + } } LogicalPlan::EmptyRelation(_) => UnsupportedPlanSnafu { name: "DataFusion Logical EmptyRelation", @@ -284,10 +319,30 @@ impl DFLogicalSubstraitConvertor { ), } .fail()?, - } + }; + + // convert extension + let extensions = ctx.generate_function_extension(); + + // assemble PlanRel + let plan_rel = PlanRel { + rel_type: Some(PlanRelType::Rel(rel)), + }; + + Ok(Plan { + extension_uris: vec![], + extensions, + relations: vec![plan_rel], + advanced_extensions: None, + expected_type_urls: vec![], + }) } - pub fn convert_table_scan_plan(&self, table_scan: TableScan) -> Result { + pub fn convert_table_scan_plan( + &self, + ctx: &mut ConvertorContext, + table_scan: TableScan, + ) -> Result { let provider = table_scan .source .as_any() @@ -313,10 +368,25 @@ impl DFLogicalSubstraitConvertor { // assemble base (unprojected) schema using Table's schema. 
let base_schema = from_schema(&provider.table().schema())?; + // make conjunction over a list of filters and convert the result to substrait + let filter = if let Some(conjunction) = table_scan + .filters + .into_iter() + .reduce(|accum, expr| accum.and(expr)) + { + Some(Box::new(expression_from_df_expr( + ctx, + &conjunction, + &provider.table().schema(), + )?)) + } else { + None + }; + let read_rel = ReadRel { common: None, base_schema: Some(base_schema), - filter: None, + filter, projection, advanced_extension: None, read_type: Some(read_type), diff --git a/src/common/substrait/src/error.rs b/src/common/substrait/src/error.rs index 74e2112a91..c33b3679fb 100644 --- a/src/common/substrait/src/error.rs +++ b/src/common/substrait/src/error.rs @@ -23,10 +23,10 @@ use snafu::{Backtrace, ErrorCompat, Snafu}; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] pub enum Error { - #[snafu(display("Unsupported physical expr: {}", name))] + #[snafu(display("Unsupported physical plan: {}", name))] UnsupportedPlan { name: String, backtrace: Backtrace }, - #[snafu(display("Unsupported physical plan: {}", name))] + #[snafu(display("Unsupported expr: {}", name))] UnsupportedExpr { name: String, backtrace: Backtrace }, #[snafu(display("Unsupported concrete type: {:?}", ty))] diff --git a/src/common/substrait/src/lib.rs b/src/common/substrait/src/lib.rs index 137796b527..c318799a3b 100644 --- a/src/common/substrait/src/lib.rs +++ b/src/common/substrait/src/lib.rs @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![feature(let_chains)] + +mod context; +mod df_expr; mod df_logical; pub mod error; mod schema;