feat: implement most of scalar function and selection conversion in substrait (#678)

* impl to_df_scalar_function

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* part of scalar functions

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* conjunction over filters

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* change the ser/de target to substrait::Plan

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* basic test coverage

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix typos and license header

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix clippy

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix CR comments

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* logs unsupported extension

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* Update src/common/substrait/src/df_expr.rs

Co-authored-by: Yingwen <realevenyag@gmail.com>

* address review comments

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* change format

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* replace context with with_context

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: Yingwen <realevenyag@gmail.com>
This commit is contained in:
Ruihang Xia
2022-12-02 14:46:05 +08:00
committed by GitHub
parent 0599465685
commit d82a3a7d58
7 changed files with 916 additions and 30 deletions

2
Cargo.lock generated
View File

@@ -6059,7 +6059,9 @@ dependencies = [
"catalog",
"common-catalog",
"common-error",
"common-telemetry",
"datafusion",
"datafusion-expr",
"datatypes",
"futures",
"prost 0.9.0",

View File

@@ -9,9 +9,11 @@ bytes = "1.1"
catalog = { path = "../../catalog" }
common-catalog = { path = "../catalog" }
common-error = { path = "../error" }
common-telemetry = { path = "../telemetry" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [
"simd",
] }
datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" }
datatypes = { path = "../../datatypes" }
futures = "0.3"
prost = "0.9"

View File

@@ -0,0 +1,66 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use substrait_proto::protobuf::extensions::simple_extension_declaration::{
ExtensionFunction, MappingType,
};
use substrait_proto::protobuf::extensions::SimpleExtensionDeclaration;
#[derive(Default)]
pub struct ConvertorContext {
scalar_fn_names: HashMap<String, u32>,
scalar_fn_map: HashMap<u32, String>,
}
impl ConvertorContext {
pub fn register_scalar_fn<S: AsRef<str>>(&mut self, name: S) -> u32 {
if let Some(anchor) = self.scalar_fn_names.get(name.as_ref()) {
return *anchor;
}
let next_anchor = self.scalar_fn_map.len() as _;
self.scalar_fn_map
.insert(next_anchor, name.as_ref().to_string());
self.scalar_fn_names
.insert(name.as_ref().to_string(), next_anchor);
next_anchor
}
pub fn register_scalar_with_anchor<S: AsRef<str>>(&mut self, name: S, anchor: u32) {
self.scalar_fn_map.insert(anchor, name.as_ref().to_string());
self.scalar_fn_names
.insert(name.as_ref().to_string(), anchor);
}
pub fn find_scalar_fn(&self, anchor: u32) -> Option<&str> {
self.scalar_fn_map.get(&anchor).map(|s| s.as_str())
}
pub fn generate_function_extension(&self) -> Vec<SimpleExtensionDeclaration> {
let mut result = Vec::with_capacity(self.scalar_fn_map.len());
for (anchor, name) in &self.scalar_fn_map {
let declaration = SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
extension_uri_reference: 0,
function_anchor: *anchor,
name: name.clone(),
})),
};
result.push(declaration);
}
result
}
}

View File

@@ -0,0 +1,742 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::VecDeque;
use std::str::FromStr;
use datafusion::logical_plan::{Column, Expr};
use datafusion_expr::{expr_fn, BuiltinScalarFunction, Operator};
use datatypes::schema::Schema;
use snafu::{ensure, OptionExt};
use substrait_proto::protobuf::expression::field_reference::ReferenceType as FieldReferenceType;
use substrait_proto::protobuf::expression::reference_segment::{
ReferenceType as SegReferenceType, StructField,
};
use substrait_proto::protobuf::expression::{
FieldReference, ReferenceSegment, RexType, ScalarFunction,
};
use substrait_proto::protobuf::function_argument::ArgType;
use substrait_proto::protobuf::Expression;
use crate::context::ConvertorContext;
use crate::error::{
EmptyExprSnafu, InvalidParametersSnafu, MissingFieldSnafu, Result, UnsupportedExprSnafu,
};
/// Convert substrait's `Expression` to DataFusion's `Expr`.
pub fn to_df_expr(ctx: &ConvertorContext, expression: Expression, schema: &Schema) -> Result<Expr> {
let expr_rex_type = expression.rex_type.context(EmptyExprSnafu)?;
match expr_rex_type {
RexType::Literal(_) => UnsupportedExprSnafu {
name: "substrait Literal expression",
}
.fail()?,
RexType::Selection(selection) => convert_selection_rex(*selection, schema),
RexType::ScalarFunction(scalar_fn) => convert_scalar_function(ctx, scalar_fn, schema),
RexType::WindowFunction(_)
| RexType::IfThen(_)
| RexType::SwitchExpression(_)
| RexType::SingularOrList(_)
| RexType::MultiOrList(_)
| RexType::Cast(_)
| RexType::Subquery(_)
| RexType::Enum(_) => UnsupportedExprSnafu {
name: format!("substrait expression {:?}", expr_rex_type),
}
.fail()?,
}
}
/// Convert Substrait's `FieldReference` - `DirectReference` - `StructField` to Datafusion's
/// `Column` expr.
pub fn convert_selection_rex(selection: FieldReference, schema: &Schema) -> Result<Expr> {
if let Some(FieldReferenceType::DirectReference(direct_ref)) = selection.reference_type
&& let Some(SegReferenceType::StructField(field)) = direct_ref.reference_type {
let column_name = schema.column_name_by_index(field.field as _).to_string();
Ok(Expr::Column(Column {
relation: None,
name: column_name,
}))
} else {
InvalidParametersSnafu {
reason: "Only support direct struct reference in Selection Rex",
}
.fail()
}
}
pub fn convert_scalar_function(
ctx: &ConvertorContext,
scalar_fn: ScalarFunction,
schema: &Schema,
) -> Result<Expr> {
// convert argument
let mut inputs = VecDeque::with_capacity(scalar_fn.arguments.len());
for arg in scalar_fn.arguments {
if let Some(ArgType::Value(sub_expr)) = arg.arg_type {
inputs.push_back(to_df_expr(ctx, sub_expr, schema)?);
} else {
InvalidParametersSnafu {
reason: "Only value expression arg is supported to be function argument",
}
.fail()?;
}
}
// convert this scalar function
// map function name
let anchor = scalar_fn.function_reference;
let fn_name = ctx
.find_scalar_fn(anchor)
.with_context(|| InvalidParametersSnafu {
reason: format!("Unregistered scalar function reference: {}", anchor),
})?;
// convenient util
let ensure_arg_len = |expected: usize| -> Result<()> {
ensure!(
inputs.len() == expected,
InvalidParametersSnafu {
reason: format!(
"Invalid number of scalar function {}, expected {} but found {}",
fn_name,
expected,
inputs.len()
)
}
);
Ok(())
};
// construct DataFusion expr
let expr = match fn_name {
// begin binary exprs, with the same order of DF `Operator`'s definition.
"eq" | "equal" => {
ensure_arg_len(2)?;
inputs.pop_front().unwrap().eq(inputs.pop_front().unwrap())
}
"not_eq" | "not_equal" => {
ensure_arg_len(2)?;
inputs
.pop_front()
.unwrap()
.not_eq(inputs.pop_front().unwrap())
}
"lt" => {
ensure_arg_len(2)?;
inputs.pop_front().unwrap().lt(inputs.pop_front().unwrap())
}
"lt_eq" | "lte" => {
ensure_arg_len(2)?;
inputs
.pop_front()
.unwrap()
.lt_eq(inputs.pop_front().unwrap())
}
"gt" => {
ensure_arg_len(2)?;
inputs.pop_front().unwrap().gt(inputs.pop_front().unwrap())
}
"gt_eq" | "gte" => {
ensure_arg_len(2)?;
inputs
.pop_front()
.unwrap()
.gt_eq(inputs.pop_front().unwrap())
}
"plus" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::Plus,
inputs.pop_front().unwrap(),
)
}
"minus" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::Minus,
inputs.pop_front().unwrap(),
)
}
"multiply" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::Multiply,
inputs.pop_front().unwrap(),
)
}
"divide" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::Divide,
inputs.pop_front().unwrap(),
)
}
"modulo" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::Modulo,
inputs.pop_front().unwrap(),
)
}
"and" => {
ensure_arg_len(2)?;
expr_fn::and(inputs.pop_front().unwrap(), inputs.pop_front().unwrap())
}
"or" => {
ensure_arg_len(2)?;
expr_fn::or(inputs.pop_front().unwrap(), inputs.pop_front().unwrap())
}
"like" => {
ensure_arg_len(2)?;
inputs
.pop_front()
.unwrap()
.like(inputs.pop_front().unwrap())
}
"not_like" => {
ensure_arg_len(2)?;
inputs
.pop_front()
.unwrap()
.not_like(inputs.pop_front().unwrap())
}
"is_distinct_from" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::IsDistinctFrom,
inputs.pop_front().unwrap(),
)
}
"is_not_distinct_from" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::IsNotDistinctFrom,
inputs.pop_front().unwrap(),
)
}
"regex_match" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::RegexMatch,
inputs.pop_front().unwrap(),
)
}
"regex_i_match" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::RegexIMatch,
inputs.pop_front().unwrap(),
)
}
"regex_not_match" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::RegexNotMatch,
inputs.pop_front().unwrap(),
)
}
"regex_not_i_match" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::RegexNotIMatch,
inputs.pop_front().unwrap(),
)
}
"bitwise_and" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::BitwiseAnd,
inputs.pop_front().unwrap(),
)
}
"bitwise_or" => {
ensure_arg_len(2)?;
expr_fn::binary_expr(
inputs.pop_front().unwrap(),
Operator::BitwiseOr,
inputs.pop_front().unwrap(),
)
}
// end binary exprs
// start other direct expr, with the same order of DF `Expr`'s definition.
"not" => {
ensure_arg_len(1)?;
inputs.pop_front().unwrap().not()
}
"is_not_null" => {
ensure_arg_len(1)?;
inputs.pop_front().unwrap().is_not_null()
}
"is_null" => {
ensure_arg_len(1)?;
inputs.pop_front().unwrap().is_null()
}
"negative" => {
ensure_arg_len(1)?;
Expr::Negative(Box::new(inputs.pop_front().unwrap()))
}
// skip GetIndexedField, unimplemented.
"between" => {
ensure_arg_len(3)?;
Expr::Between {
expr: Box::new(inputs.pop_front().unwrap()),
negated: false,
low: Box::new(inputs.pop_front().unwrap()),
high: Box::new(inputs.pop_front().unwrap()),
}
}
"not_between" => {
ensure_arg_len(3)?;
Expr::Between {
expr: Box::new(inputs.pop_front().unwrap()),
negated: true,
low: Box::new(inputs.pop_front().unwrap()),
high: Box::new(inputs.pop_front().unwrap()),
}
}
// skip Case, is covered in substrait::SwitchExpression.
// skip Cast and TryCast, is covered in substrait::Cast.
"sort" | "sort_des" => {
ensure_arg_len(1)?;
Expr::Sort {
expr: Box::new(inputs.pop_front().unwrap()),
asc: false,
nulls_first: false,
}
}
"sort_asc" => {
ensure_arg_len(1)?;
Expr::Sort {
expr: Box::new(inputs.pop_front().unwrap()),
asc: true,
nulls_first: false,
}
}
// those are datafusion built-in "scalar functions".
"abs"
| "acos"
| "asin"
| "atan"
| "atan2"
| "ceil"
| "cos"
| "exp"
| "floor"
| "ln"
| "log"
| "log10"
| "log2"
| "power"
| "pow"
| "round"
| "signum"
| "sin"
| "sqrt"
| "tan"
| "trunc"
| "coalesce"
| "make_array"
| "ascii"
| "bit_length"
| "btrim"
| "char_length"
| "character_length"
| "concat"
| "concat_ws"
| "chr"
| "current_date"
| "current_time"
| "date_part"
| "datepart"
| "date_trunc"
| "datetrunc"
| "date_bin"
| "initcap"
| "left"
| "length"
| "lower"
| "lpad"
| "ltrim"
| "md5"
| "nullif"
| "octet_length"
| "random"
| "regexp_replace"
| "repeat"
| "replace"
| "reverse"
| "right"
| "rpad"
| "rtrim"
| "sha224"
| "sha256"
| "sha384"
| "sha512"
| "digest"
| "split_part"
| "starts_with"
| "strpos"
| "substr"
| "to_hex"
| "to_timestamp"
| "to_timestamp_millis"
| "to_timestamp_micros"
| "to_timestamp_seconds"
| "now"
| "translate"
| "trim"
| "upper"
| "uuid"
| "regexp_match"
| "struct"
| "from_unixtime"
| "arrow_typeof" => Expr::ScalarFunction {
fun: BuiltinScalarFunction::from_str(fn_name).unwrap(),
args: inputs.into(),
},
// skip ScalarUDF, unimplemented.
// skip AggregateFunction, is covered in substrait::AggregateRel
// skip WindowFunction, is covered in substrait WindowFunction
// skip AggregateUDF, unimplemented.
// skip InList, unimplemented
// skip Wildcard, unimplemented.
// end other direct expr
_ => UnsupportedExprSnafu {
name: format!("scalar function {}", fn_name),
}
.fail()?,
};
Ok(expr)
}
/// Convert DataFusion's `Expr` to substrait's `Expression`
pub fn expression_from_df_expr(
ctx: &mut ConvertorContext,
expr: &Expr,
schema: &Schema,
) -> Result<Expression> {
let expression = match expr {
// Don't merge them with other unsupported expr arms to preserve the ordering.
Expr::Alias(..) => UnsupportedExprSnafu {
name: expr.to_string(),
}
.fail()?,
Expr::Column(column) => {
let field_reference = convert_column(column, schema)?;
Expression {
rex_type: Some(RexType::Selection(Box::new(field_reference))),
}
}
// Don't merge them with other unsupported expr arms to preserve the ordering.
Expr::ScalarVariable(..) | Expr::Literal(..) => UnsupportedExprSnafu {
name: expr.to_string(),
}
.fail()?,
Expr::BinaryExpr { left, op, right } => {
let left = expression_from_df_expr(ctx, left, schema)?;
let right = expression_from_df_expr(ctx, right, schema)?;
let arguments = utils::expression_to_argument(vec![left, right]);
let op_name = utils::name_df_operator(op);
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
Expr::Not(e) => {
let arg = expression_from_df_expr(ctx, e, schema)?;
let arguments = utils::expression_to_argument(vec![arg]);
let op_name = "not";
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
Expr::IsNotNull(e) => {
let arg = expression_from_df_expr(ctx, e, schema)?;
let arguments = utils::expression_to_argument(vec![arg]);
let op_name = "is_not_null";
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
Expr::IsNull(e) => {
let arg = expression_from_df_expr(ctx, e, schema)?;
let arguments = utils::expression_to_argument(vec![arg]);
let op_name = "is_null";
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
Expr::Negative(e) => {
let arg = expression_from_df_expr(ctx, e, schema)?;
let arguments = utils::expression_to_argument(vec![arg]);
let op_name = "negative";
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
// Don't merge them with other unsupported expr arms to preserve the ordering.
Expr::GetIndexedField { .. } => UnsupportedExprSnafu {
name: expr.to_string(),
}
.fail()?,
Expr::Between {
expr,
negated,
low,
high,
} => {
let expr = expression_from_df_expr(ctx, expr, schema)?;
let low = expression_from_df_expr(ctx, low, schema)?;
let high = expression_from_df_expr(ctx, high, schema)?;
let arguments = utils::expression_to_argument(vec![expr, low, high]);
let op_name = if *negated { "not_between" } else { "between" };
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
// Don't merge them with other unsupported expr arms to preserve the ordering.
Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } => UnsupportedExprSnafu {
name: expr.to_string(),
}
.fail()?,
Expr::Sort {
expr,
asc,
nulls_first: _,
} => {
let expr = expression_from_df_expr(ctx, expr, schema)?;
let arguments = utils::expression_to_argument(vec![expr]);
let op_name = if *asc { "sort_asc" } else { "sort_des" };
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
Expr::ScalarFunction { fun, args } => {
let arguments = utils::expression_to_argument(
args.iter()
.map(|e| expression_from_df_expr(ctx, e, schema))
.collect::<Result<Vec<_>>>()?,
);
let op_name = utils::name_builtin_scalar_function(fun);
let function_reference = ctx.register_scalar_fn(op_name);
utils::build_scalar_function_expression(function_reference, arguments)
}
// Don't merge them with other unsupported expr arms to preserve the ordering.
Expr::ScalarUDF { .. }
| Expr::AggregateFunction { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateUDF { .. }
| Expr::InList { .. }
| Expr::Wildcard => UnsupportedExprSnafu {
name: expr.to_string(),
}
.fail()?,
};
Ok(expression)
}
/// Convert DataFusion's `Column` expr into substrait's `FieldReference` -
/// `DirectReference` - `StructField`.
pub fn convert_column(column: &Column, schema: &Schema) -> Result<FieldReference> {
let column_name = &column.name;
let field_index =
schema
.column_index_by_name(column_name)
.with_context(|| MissingFieldSnafu {
field: format!("{:?}", column),
plan: format!("schema: {:?}", schema),
})?;
Ok(FieldReference {
reference_type: Some(FieldReferenceType::DirectReference(ReferenceSegment {
reference_type: Some(SegReferenceType::StructField(Box::new(StructField {
field: field_index as _,
child: None,
}))),
})),
root_type: None,
})
}
/// Some utils special for this `DataFusion::Expr` and `Substrait::Expression` conversion.
mod utils {
use datafusion_expr::{BuiltinScalarFunction, Operator};
use substrait_proto::protobuf::expression::{RexType, ScalarFunction};
use substrait_proto::protobuf::function_argument::ArgType;
use substrait_proto::protobuf::{Expression, FunctionArgument};
pub(crate) fn name_df_operator(op: &Operator) -> &str {
match op {
Operator::Eq => "equal",
Operator::NotEq => "not_equal",
Operator::Lt => "lt",
Operator::LtEq => "lte",
Operator::Gt => "gt",
Operator::GtEq => "gte",
Operator::Plus => "plus",
Operator::Minus => "minus",
Operator::Multiply => "multiply",
Operator::Divide => "divide",
Operator::Modulo => "modulo",
Operator::And => "and",
Operator::Or => "or",
Operator::Like => "like",
Operator::NotLike => "not_like",
Operator::IsDistinctFrom => "is_distinct_from",
Operator::IsNotDistinctFrom => "is_not_distinct_from",
Operator::RegexMatch => "regex_match",
Operator::RegexIMatch => "regex_i_match",
Operator::RegexNotMatch => "regex_not_match",
Operator::RegexNotIMatch => "regex_not_i_match",
Operator::BitwiseAnd => "bitwise_and",
Operator::BitwiseOr => "bitwise_or",
}
}
/// Convert list of [Expression] to [FunctionArgument] vector.
pub(crate) fn expression_to_argument<I: IntoIterator<Item = Expression>>(
expressions: I,
) -> Vec<FunctionArgument> {
expressions
.into_iter()
.map(|expr| FunctionArgument {
arg_type: Some(ArgType::Value(expr)),
})
.collect()
}
/// Convenient builder for [Expression]
pub(crate) fn build_scalar_function_expression(
function_reference: u32,
arguments: Vec<FunctionArgument>,
) -> Expression {
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference,
arguments,
output_type: None,
..Default::default()
})),
}
}
pub(crate) fn name_builtin_scalar_function(fun: &BuiltinScalarFunction) -> &str {
match fun {
BuiltinScalarFunction::Abs => "abs",
BuiltinScalarFunction::Acos => "acos",
BuiltinScalarFunction::Asin => "asin",
BuiltinScalarFunction::Atan => "atan",
BuiltinScalarFunction::Ceil => "ceil",
BuiltinScalarFunction::Cos => "cos",
BuiltinScalarFunction::Digest => "digest",
BuiltinScalarFunction::Exp => "exp",
BuiltinScalarFunction::Floor => "floor",
BuiltinScalarFunction::Ln => "ln",
BuiltinScalarFunction::Log => "log",
BuiltinScalarFunction::Log10 => "log10",
BuiltinScalarFunction::Log2 => "log2",
BuiltinScalarFunction::Round => "round",
BuiltinScalarFunction::Signum => "signum",
BuiltinScalarFunction::Sin => "sin",
BuiltinScalarFunction::Sqrt => "sqrt",
BuiltinScalarFunction::Tan => "tan",
BuiltinScalarFunction::Trunc => "trunc",
BuiltinScalarFunction::Array => "make_array",
BuiltinScalarFunction::Ascii => "ascii",
BuiltinScalarFunction::BitLength => "bit_length",
BuiltinScalarFunction::Btrim => "btrim",
BuiltinScalarFunction::CharacterLength => "character_length",
BuiltinScalarFunction::Chr => "chr",
BuiltinScalarFunction::Concat => "concat",
BuiltinScalarFunction::ConcatWithSeparator => "concat_ws",
BuiltinScalarFunction::DatePart => "date_part",
BuiltinScalarFunction::DateTrunc => "date_trunc",
BuiltinScalarFunction::InitCap => "initcap",
BuiltinScalarFunction::Left => "left",
BuiltinScalarFunction::Lpad => "lpad",
BuiltinScalarFunction::Lower => "lower",
BuiltinScalarFunction::Ltrim => "ltrim",
BuiltinScalarFunction::MD5 => "md5",
BuiltinScalarFunction::NullIf => "nullif",
BuiltinScalarFunction::OctetLength => "octet_length",
BuiltinScalarFunction::Random => "random",
BuiltinScalarFunction::RegexpReplace => "regexp_replace",
BuiltinScalarFunction::Repeat => "repeat",
BuiltinScalarFunction::Replace => "replace",
BuiltinScalarFunction::Reverse => "reverse",
BuiltinScalarFunction::Right => "right",
BuiltinScalarFunction::Rpad => "rpad",
BuiltinScalarFunction::Rtrim => "rtrim",
BuiltinScalarFunction::SHA224 => "sha224",
BuiltinScalarFunction::SHA256 => "sha256",
BuiltinScalarFunction::SHA384 => "sha384",
BuiltinScalarFunction::SHA512 => "sha512",
BuiltinScalarFunction::SplitPart => "split_part",
BuiltinScalarFunction::StartsWith => "starts_with",
BuiltinScalarFunction::Strpos => "strpos",
BuiltinScalarFunction::Substr => "substr",
BuiltinScalarFunction::ToHex => "to_hex",
BuiltinScalarFunction::ToTimestamp => "to_timestamp",
BuiltinScalarFunction::ToTimestampMillis => "to_timestamp_millis",
BuiltinScalarFunction::ToTimestampMicros => "to_timestamp_macros",
BuiltinScalarFunction::ToTimestampSeconds => "to_timestamp_seconds",
BuiltinScalarFunction::Now => "now",
BuiltinScalarFunction::Translate => "translate",
BuiltinScalarFunction::Trim => "trim",
BuiltinScalarFunction::Upper => "upper",
BuiltinScalarFunction::RegexpMatch => "regexp_match",
}
}
}
#[cfg(test)]
mod test {
use datatypes::schema::ColumnSchema;
use super::*;
#[test]
fn expr_round_trip() {
let expr = expr_fn::and(
expr_fn::col("column_a").lt_eq(expr_fn::col("column_b")),
expr_fn::col("column_a").gt(expr_fn::col("column_b")),
);
let schema = Schema::new(vec![
ColumnSchema::new(
"column_a",
datatypes::data_type::ConcreteDataType::int64_datatype(),
true,
),
ColumnSchema::new(
"column_b",
datatypes::data_type::ConcreteDataType::float64_datatype(),
true,
),
]);
let mut ctx = ConvertorContext::default();
let substrait_expr = expression_from_df_expr(&mut ctx, &expr, &schema).unwrap();
let converted_expr = to_df_expr(&ctx, substrait_expr, &schema).unwrap();
assert_eq!(expr, converted_expr);
}
}

View File

@@ -17,6 +17,7 @@ use std::sync::Arc;
use bytes::{Buf, Bytes, BytesMut};
use catalog::CatalogManagerRef;
use common_error::prelude::BoxedError;
use common_telemetry::debug;
use datafusion::datasource::TableProvider;
use datafusion::logical_plan::{LogicalPlan, TableScan, ToDFSchema};
use datafusion::physical_plan::project_schema;
@@ -24,12 +25,15 @@ use prost::Message;
use snafu::{ensure, OptionExt, ResultExt};
use substrait_proto::protobuf::expression::mask_expression::{StructItem, StructSelect};
use substrait_proto::protobuf::expression::MaskExpression;
use substrait_proto::protobuf::extensions::simple_extension_declaration::MappingType;
use substrait_proto::protobuf::plan_rel::RelType as PlanRelType;
use substrait_proto::protobuf::read_rel::{NamedTable, ReadType};
use substrait_proto::protobuf::rel::RelType;
use substrait_proto::protobuf::{PlanRel, ReadRel, Rel};
use substrait_proto::protobuf::{Plan, PlanRel, ReadRel, Rel};
use table::table::adapter::DfTableProviderAdapter;
use crate::context::ConvertorContext;
use crate::df_expr::{expression_from_df_expr, to_df_expr};
use crate::error::{
DFInternalSnafu, DecodeRelSnafu, EmptyPlanSnafu, EncodeRelSnafu, Error, InternalSnafu,
InvalidParametersSnafu, MissingFieldSnafu, SchemaNotMatchSnafu, TableNotFoundSnafu,
@@ -48,25 +52,15 @@ impl SubstraitPlan for DFLogicalSubstraitConvertor {
type Plan = LogicalPlan;
fn decode<B: Buf + Send>(&self, message: B) -> Result<Self::Plan, Self::Error> {
let plan_rel = PlanRel::decode(message).context(DecodeRelSnafu)?;
let rel = match plan_rel.rel_type.context(EmptyPlanSnafu)? {
PlanRelType::Rel(rel) => rel,
PlanRelType::Root(_) => UnsupportedPlanSnafu {
name: "Root Relation",
}
.fail()?,
};
self.convert_rel(rel)
let plan = Plan::decode(message).context(DecodeRelSnafu)?;
self.convert_plan(plan)
}
fn encode(&self, plan: Self::Plan) -> Result<Bytes, Self::Error> {
let rel = self.convert_plan(plan)?;
let plan_rel = PlanRel {
rel_type: Some(PlanRelType::Rel(rel)),
};
let plan = self.convert_df_plan(plan)?;
let mut buf = BytesMut::new();
plan_rel.encode(&mut buf).context(EncodeRelSnafu)?;
plan.encode(&mut buf).context(EncodeRelSnafu)?;
Ok(buf.freeze())
}
@@ -79,10 +73,37 @@ impl DFLogicalSubstraitConvertor {
}
impl DFLogicalSubstraitConvertor {
pub fn convert_rel(&self, rel: Rel) -> Result<LogicalPlan, Error> {
pub fn convert_plan(&self, mut plan: Plan) -> Result<LogicalPlan, Error> {
// prepare convertor context
let mut ctx = ConvertorContext::default();
for simple_ext in plan.extensions {
if let Some(MappingType::ExtensionFunction(function_extension)) =
simple_ext.mapping_type
{
ctx.register_scalar_with_anchor(
function_extension.name,
function_extension.function_anchor,
);
} else {
debug!("Encounter unsupported substrait extension {:?}", simple_ext);
}
}
// extract rel
let rel = if let Some(PlanRel { rel_type }) = plan.relations.pop()
&& let Some(PlanRelType::Rel(rel)) = rel_type {
rel
} else {
UnsupportedPlanSnafu {
name: "Emply or non-Rel relation",
}
.fail()?
};
let rel_type = rel.rel_type.context(EmptyPlanSnafu)?;
// build logical plan
let logical_plan = match rel_type {
RelType::Read(read_rel) => self.convert_read_rel(read_rel),
RelType::Read(read_rel) => self.convert_read_rel(&mut ctx, read_rel),
RelType::Filter(_filter_rel) => UnsupportedPlanSnafu {
name: "Filter Relation",
}
@@ -132,9 +153,12 @@ impl DFLogicalSubstraitConvertor {
Ok(logical_plan)
}
fn convert_read_rel(&self, read_rel: Box<ReadRel>) -> Result<LogicalPlan, Error> {
fn convert_read_rel(
&self,
ctx: &mut ConvertorContext,
read_rel: Box<ReadRel>,
) -> Result<LogicalPlan, Error> {
// Extract the catalog, schema and table name from NamedTable. Assume the first three are those names.
let read_type = read_rel.read_type.context(MissingFieldSnafu {
field: "read_type",
plan: "Read",
@@ -190,6 +214,13 @@ impl DFLogicalSubstraitConvertor {
}
);
// Convert filter
let filters = if let Some(filter) = read_rel.filter {
vec![to_df_expr(ctx, *filter, &retrieved_schema)?]
} else {
vec![]
};
// Calculate the projected schema
let projected_schema = project_schema(&stored_schema, projection.as_ref())
.context(DFInternalSnafu)?
@@ -202,7 +233,7 @@ impl DFLogicalSubstraitConvertor {
source: adapter,
projection,
projected_schema,
filters: vec![],
filters,
limit: None,
}))
}
@@ -219,8 +250,12 @@ impl DFLogicalSubstraitConvertor {
}
impl DFLogicalSubstraitConvertor {
pub fn convert_plan(&self, plan: LogicalPlan) -> Result<Rel, Error> {
match plan {
pub fn convert_df_plan(&self, plan: LogicalPlan) -> Result<Plan, Error> {
let mut ctx = ConvertorContext::default();
// TODO(ruihang): extract this translation logic into a separated function
// convert PlanRel
let rel = match plan {
LogicalPlan::Projection(_) => UnsupportedPlanSnafu {
name: "DataFusion Logical Projection",
}
@@ -258,10 +293,10 @@ impl DFLogicalSubstraitConvertor {
}
.fail()?,
LogicalPlan::TableScan(table_scan) => {
let read_rel = self.convert_table_scan_plan(table_scan)?;
Ok(Rel {
let read_rel = self.convert_table_scan_plan(&mut ctx, table_scan)?;
Rel {
rel_type: Some(RelType::Read(Box::new(read_rel))),
})
}
}
LogicalPlan::EmptyRelation(_) => UnsupportedPlanSnafu {
name: "DataFusion Logical EmptyRelation",
@@ -284,10 +319,30 @@ impl DFLogicalSubstraitConvertor {
),
}
.fail()?,
}
};
// convert extension
let extensions = ctx.generate_function_extension();
// assemble PlanRel
let plan_rel = PlanRel {
rel_type: Some(PlanRelType::Rel(rel)),
};
Ok(Plan {
extension_uris: vec![],
extensions,
relations: vec![plan_rel],
advanced_extensions: None,
expected_type_urls: vec![],
})
}
pub fn convert_table_scan_plan(&self, table_scan: TableScan) -> Result<ReadRel, Error> {
pub fn convert_table_scan_plan(
&self,
ctx: &mut ConvertorContext,
table_scan: TableScan,
) -> Result<ReadRel, Error> {
let provider = table_scan
.source
.as_any()
@@ -313,10 +368,25 @@ impl DFLogicalSubstraitConvertor {
// assemble base (unprojected) schema using Table's schema.
let base_schema = from_schema(&provider.table().schema())?;
// make conjunction over a list of filters and convert the result to substrait
let filter = if let Some(conjunction) = table_scan
.filters
.into_iter()
.reduce(|accum, expr| accum.and(expr))
{
Some(Box::new(expression_from_df_expr(
ctx,
&conjunction,
&provider.table().schema(),
)?))
} else {
None
};
let read_rel = ReadRel {
common: None,
base_schema: Some(base_schema),
filter: None,
filter,
projection,
advanced_extension: None,
read_type: Some(read_type),

View File

@@ -23,10 +23,10 @@ use snafu::{Backtrace, ErrorCompat, Snafu};
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("Unsupported physical expr: {}", name))]
#[snafu(display("Unsupported physical plan: {}", name))]
UnsupportedPlan { name: String, backtrace: Backtrace },
#[snafu(display("Unsupported physical plan: {}", name))]
#[snafu(display("Unsupported expr: {}", name))]
UnsupportedExpr { name: String, backtrace: Backtrace },
#[snafu(display("Unsupported concrete type: {:?}", ty))]

View File

@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![feature(let_chains)]
mod context;
mod df_expr;
mod df_logical;
pub mod error;
mod schema;