feat(flow): support datafusion scalar function (#4142)

* chore: call df function types

* feat: RelationDesc to DfSchema

* refactor: use RelationDesc instead of Type

* chore: WIP get to phy expr

* feat: custom deserialize

* chore: fmt

* refactor: renaming to DfScalarFunction

* feat: eval df func(untested)

* fix: had to spawn a thread for calling async

* chore: per review advices

* tests: test df scalar function
This commit is contained in:
discord9
2024-06-18 20:34:38 +08:00
committed by GitHub
parent ea2d067cf1
commit cd9705ccd7
14 changed files with 434 additions and 69 deletions

4
Cargo.lock generated
View File

@@ -3809,7 +3809,9 @@ name = "flow"
version = "0.8.2"
dependencies = [
"api",
"arrow-schema",
"async-trait",
"bytes",
"catalog",
"common-base",
"common-catalog",
@@ -3824,8 +3826,10 @@ dependencies = [
"common-runtime",
"common-telemetry",
"common-time",
"datafusion 38.0.0",
"datafusion-common 38.0.0",
"datafusion-expr 38.0.0",
"datafusion-physical-expr 38.0.0",
"datatypes",
"enum-as-inner",
"enum_dispatch",

View File

@@ -9,29 +9,33 @@ workspace = true
[dependencies]
api.workspace = true
arrow-schema.workspace = true
async-trait.workspace = true
bytes.workspace = true
catalog.workspace = true
common-base.workspace = true
common-decimal.workspace = true
common-error.workspace = true
common-frontend.workspace = true
common-macro.workspace = true
common-runtime.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datatypes.workspace = true
enum_dispatch = "0.3"
futures = "0.3"
# This fork is simply for keeping our dependency in our org, and pin the version
# it is the same with upstream repo
async-trait.workspace = true
common-function.workspace = true
common-macro.workspace = true
common-meta.workspace = true
common-query.workspace = true
common-recordbatch.workspace = true
common-runtime.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datatypes.workspace = true
enum-as-inner = "0.6.0"
enum_dispatch = "0.3"
futures = "0.3"
greptime-proto.workspace = true
# This fork of hydroflow is simply for keeping our dependency in our org, and pin the version
# otherwise it is the same with upstream repo
hydroflow = { git = "https://github.com/GreptimeTeam/hydroflow.git", branch = "main" }
itertools.workspace = true
minstant = "0.1.7"

View File

@@ -64,12 +64,12 @@ mod table_source;
use error::Error;
// TODO: replace this with `GREPTIME_TIMESTAMP` before v0.9
// TODO(discord9): replace this with `GREPTIME_TIMESTAMP` before v0.9
pub const AUTO_CREATED_PLACEHOLDER_TS_COL: &str = "__ts_placeholder";
pub const UPDATE_AT_TS_COL: &str = "update_at";
// TODO: refactor common types for flow to a separate module
// TODO(discord9): refactor common types for flow to a separate module
/// FlowId is a unique identifier for a flow task
pub type FlowId = u64;
pub type TableName = [String; 3];

View File

@@ -79,7 +79,6 @@ impl Default for SourceSender {
}
}
// TODO: make all send operation immut
impl SourceSender {
pub fn get_receiver(&self) -> broadcast::Receiver<DiffRow> {
self.sender.subscribe()

View File

@@ -27,4 +27,4 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}
pub(crate) use id::{GlobalId, Id, LocalId};
pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan};
pub(crate) use relation::{AggregateExpr, AggregateFunc};
pub(crate) use scalar::{ScalarExpr, TypedExpr};
pub(crate) use scalar::{DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr};

View File

@@ -16,12 +16,15 @@
use std::any::Any;
use arrow_schema::ArrowError;
use common_error::ext::BoxedError;
use common_macro::stack_trace_debug;
use common_telemetry::common_error::ext::ErrorExt;
use common_telemetry::common_error::status_code::StatusCode;
use datafusion_common::DataFusionError;
use datatypes::data_type::ConcreteDataType;
use serde::{Deserialize, Serialize};
use snafu::{Location, Snafu};
use snafu::{Location, ResultExt, Snafu};
fn is_send_sync() {
fn check<T: Send + Sync>() {}
@@ -107,4 +110,27 @@ pub enum EvalError {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Arrow error: {raw:?}, context: {context}"))]
Arrow {
raw: ArrowError,
context: String,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("DataFusion error: {raw:?}, context: {context}"))]
Datafusion {
raw: DataFusionError,
context: String,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("External error"))]
External {
#[snafu(implicit)]
location: Location,
source: BoxedError,
},
}

View File

@@ -814,7 +814,7 @@ impl VariadicFunc {
name: &str,
arg_types: &[Option<ConcreteDataType>],
) -> Result<Self, Error> {
// TODO: future variadic funcs to be added might need to check arg_types
// TODO(discord9): future variadic funcs to be added might need to check arg_types
let _ = arg_types;
match name {
"and" => Ok(Self::And),

View File

@@ -15,19 +15,33 @@
//! Scalar expressions.
use std::collections::{BTreeMap, BTreeSet};
use std::sync::{Arc, Mutex};
use bytes::BytesMut;
use common_error::ext::BoxedError;
use common_recordbatch::DfRecordBatch;
use datafusion_physical_expr::PhysicalExpr;
use datatypes::arrow_array;
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::Value;
use prost::Message;
use serde::{Deserialize, Serialize};
use snafu::ensure;
use snafu::{ensure, ResultExt};
use substrait::error::{DecodeRelSnafu, EncodeRelSnafu};
use substrait::substrait_proto_df::proto::expression::{RexType, ScalarFunction};
use substrait::substrait_proto_df::proto::Expression;
use crate::adapter::error::{
Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
};
use crate::expr::error::{
ArrowSnafu, DatafusionSnafu as EvalDatafusionSnafu, EvalError, ExternalSnafu,
InvalidArgumentSnafu, OptimizeSnafu,
};
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
use crate::repr::{ColumnType, RelationType};
use crate::repr::{ColumnType, RelationDesc, RelationType};
use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions};
/// A scalar expression with a known type.
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub struct TypedExpr {
@@ -47,6 +61,8 @@ impl TypedExpr {
/// expand multi-value expression to multiple expressions with new indices
///
/// Currently it just mean expand `TumbleWindow` to `TumbleWindowFloor` and `TumbleWindowCeiling`
///
/// TODO(discord9): test if nested reduce combine with df scalar function would cause problem
pub fn expand_multi_value(
input_typ: &RelationType,
exprs: &[TypedExpr],
@@ -138,6 +154,10 @@ pub enum ScalarExpr {
func: VariadicFunc,
exprs: Vec<ScalarExpr>,
},
CallDf {
// TODO(discord9): support shuffle
df_scalar_fn: DfScalarFunction,
},
/// Conditionally evaluated expressions.
///
/// It is important that `then` and `els` only be evaluated if
@@ -151,6 +171,161 @@ pub enum ScalarExpr {
},
}
/// A way to represent a scalar function that is implemented in Datafusion
#[derive(Debug, Clone)]
pub struct DfScalarFunction {
raw_fn: RawDfScalarFn,
// TODO(discord9): directly from datafusion expr
fn_impl: Arc<dyn PhysicalExpr>,
df_schema: Arc<datafusion_common::DFSchema>,
}
impl DfScalarFunction {
pub fn new(raw_fn: RawDfScalarFn, fn_impl: Arc<dyn PhysicalExpr>) -> Result<Self, Error> {
Ok(Self {
df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
raw_fn,
fn_impl,
})
}
// TODO(discord9): add RecordBatch support
pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
if values.is_empty() {
return InvalidArgumentSnafu {
reason: "values is empty".to_string(),
}
.fail();
}
// TODO(discord9): make cols all array length of one
let mut cols = vec![];
for (idx, typ) in self
.raw_fn
.input_schema
.typ()
.column_types
.iter()
.enumerate()
{
let typ = typ.scalar_type();
let mut array = typ.create_mutable_vector(1);
array.push_value_ref(values[idx].as_value_ref());
cols.push(array.to_vector().to_arrow_array());
}
let schema = self.df_schema.inner().clone();
let rb = DfRecordBatch::try_new(schema, cols).map_err(|err| {
ArrowSnafu {
raw: err,
context:
"Failed to create RecordBatch from values when eval datafusion scalar function",
}
.build()
})?;
let res = self.fn_impl.evaluate(&rb).map_err(|err| {
EvalDatafusionSnafu {
raw: err,
context: "Failed to evaluate datafusion scalar function",
}
.build()
})?;
let res = common_query::columnar_value::ColumnarValue::try_from(&res)
.map_err(BoxedError::new)
.context(ExternalSnafu)?;
let res_vec = res
.try_into_vector(1)
.map_err(BoxedError::new)
.context(ExternalSnafu)?;
let res_val = res_vec
.try_get(0)
.map_err(BoxedError::new)
.context(ExternalSnafu)?;
Ok(res_val)
}
}
// simply serialize the raw_fn instead of derive to avoid complex deserialize of struct
impl Serialize for DfScalarFunction {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.raw_fn.serialize(serializer)
}
}
impl<'de> serde::de::Deserialize<'de> for DfScalarFunction {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let raw_fn = RawDfScalarFn::deserialize(deserializer)?;
let fn_impl = raw_fn.get_fn_impl().map_err(serde::de::Error::custom)?;
DfScalarFunction::new(raw_fn, fn_impl).map_err(serde::de::Error::custom)
}
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RawDfScalarFn {
f: bytes::BytesMut,
input_schema: RelationDesc,
extensions: FunctionExtensions,
}
impl RawDfScalarFn {
pub fn from_proto(
f: &substrait::substrait_proto_df::proto::expression::ScalarFunction,
input_schema: RelationDesc,
extensions: FunctionExtensions,
) -> Result<Self, Error> {
let mut buf = BytesMut::new();
f.encode(&mut buf)
.context(EncodeRelSnafu)
.map_err(BoxedError::new)
.context(crate::adapter::error::ExternalSnafu)?;
Ok(Self {
f: buf,
input_schema,
extensions,
})
}
fn get_fn_impl(&self) -> Result<Arc<dyn PhysicalExpr>, Error> {
let f = ScalarFunction::decode(&mut self.f.as_ref())
.context(DecodeRelSnafu)
.map_err(BoxedError::new)
.context(crate::adapter::error::ExternalSnafu)?;
let input_schema = &self.input_schema;
let extensions = &self.extensions;
from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions)
}
}
impl std::cmp::PartialEq for DfScalarFunction {
fn eq(&self, other: &Self) -> bool {
self.raw_fn.eq(&other.raw_fn)
}
}
// can't derive Eq because of Arc<dyn PhysicalExpr> not eq, so implement it manually
impl std::cmp::Eq for DfScalarFunction {}
impl std::cmp::PartialOrd for DfScalarFunction {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl std::cmp::Ord for DfScalarFunction {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.raw_fn.cmp(&other.raw_fn)
}
}
impl std::hash::Hash for DfScalarFunction {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.raw_fn.hash(state);
}
}
impl ScalarExpr {
pub fn with_type(self, typ: ColumnType) -> TypedExpr {
TypedExpr::new(self, typ)
@@ -179,6 +354,23 @@ impl ScalarExpr {
Ok(ColumnType::new_nullable(func.signature().output))
}
ScalarExpr::If { then, .. } => then.typ(context),
ScalarExpr::CallDf { df_scalar_fn } => {
let arrow_typ = df_scalar_fn
.fn_impl
// TODO(discord9): get scheme from args instead?
.data_type(df_scalar_fn.df_schema.as_arrow())
.map_err(|err| {
DatafusionSnafu {
raw: err,
context: "Failed to get data type from datafusion scalar function",
}
.build()
})?;
let typ = ConcreteDataType::try_from(&arrow_typ)
.map_err(BoxedError::new)
.context(crate::adapter::error::ExternalSnafu)?;
Ok(ColumnType::new_nullable(typ))
}
}
}
}
@@ -253,6 +445,7 @@ impl ScalarExpr {
}
.fail(),
},
ScalarExpr::CallDf { df_scalar_fn } => df_scalar_fn.eval(values),
}
}
@@ -421,6 +614,7 @@ impl ScalarExpr {
f(then)?;
f(els)
}
_ => Ok(()),
}
}
@@ -456,6 +650,7 @@ impl ScalarExpr {
f(then)?;
f(els)
}
_ => Ok(()),
}
}
}
@@ -497,7 +692,7 @@ impl ScalarExpr {
return unsupported_err("Not a binary expression");
};
// TODO: support simple transform like `now() + a < b` to `now() < b - a`
// TODO(discord9): support simple transform like `now() + a < b` to `now() < b - a`
let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
@@ -531,6 +726,15 @@ impl ScalarExpr {
#[cfg(test)]
mod test {
use datatypes::arrow::array::Scalar;
use query::parser::QueryLanguageParser;
use query::QueryEngine;
use session::context::QueryContext;
use substrait::extension_serializer;
use substrait::substrait_proto_df::proto::expression::literal::LiteralType;
use substrait::substrait_proto_df::proto::expression::Literal;
use substrait::substrait_proto_df::proto::function_argument::ArgType;
use substrait::substrait_proto_df::proto::r#type::Kind;
use substrait::substrait_proto_df::proto::{r#type, FunctionArgument, Type};
use super::*;
#[test]
@@ -622,4 +826,37 @@ mod test {
let res = expr.permute_map(&permute_map);
assert!(matches!(res, Err(Error::InvalidQuery { .. })));
}
#[tokio::test]
async fn test_df_scalar_function() {
let raw_scalar_func = ScalarFunction {
function_reference: 0,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(Expression {
rex_type: Some(RexType::Literal(Literal {
nullable: false,
type_variation_reference: 0,
literal_type: Some(LiteralType::I64(-1)),
})),
})),
}],
output_type: None,
..Default::default()
};
let input_schema = RelationDesc::try_new(
RelationType::new(vec![ColumnType::new_nullable(
ConcreteDataType::null_datatype(),
)]),
vec!["null_column".to_string()],
)
.unwrap();
let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);
let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
let fn_impl = raw_fn.get_fn_impl().unwrap();
let df_func = DfScalarFunction::new(raw_fn, fn_impl).unwrap();
let as_str = serde_json::to_string(&df_func).unwrap();
let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap();
assert_eq!(df_func, from_str);
assert_eq!(df_func.eval(&[Value::Null]).unwrap(), Value::Int64(1));
}
}

View File

@@ -14,12 +14,14 @@
use std::collections::{BTreeMap, HashMap};
use datafusion_common::DFSchema;
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt};
use snafu::{ensure, OptionExt, ResultExt};
use crate::adapter::error::{InvalidQuerySnafu, Result, UnexpectedSnafu};
use crate::adapter::error::{DatafusionSnafu, InvalidQuerySnafu, Result, UnexpectedSnafu};
use crate::expr::{MapFilterProject, SafeMfpPlan, ScalarExpr};
/// a set of column indices that are "keys" for the collection.
@@ -338,9 +340,31 @@ pub struct RelationDesc {
}
impl RelationDesc {
pub fn to_df_schema(&self) -> Result<DFSchema> {
let fields: Vec<_> = self
.iter()
.enumerate()
.map(|(i, (name, typ))| {
let name = name.clone().unwrap_or(format!("Col_{i}"));
let nullable = typ.nullable;
let data_type = typ.scalar_type.clone().as_arrow_type();
arrow_schema::Field::new(name, data_type, nullable)
})
.collect();
let arrow_schema = arrow_schema::Schema::new(fields);
DFSchema::try_from(arrow_schema.clone()).map_err(|err| {
DatafusionSnafu {
raw: err,
context: format!("Error when converting to DFSchema: {:?}", arrow_schema),
}
.build()
})
}
/// apply mfp, and also project col names for the projected columns
pub fn apply_mfp(&self, mfp: &SafeMfpPlan) -> Result<Self> {
// TODO: find a way to deduce name at best effect
// TODO(discord9): find a way to deduce name at best effect
let names = {
let mfp = &mfp.mfp;
let mut names = self.names.clone();

View File

@@ -13,9 +13,10 @@
// limitations under the License.
//! Transform Substrait into execution plan
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use bytes::buf::IntoIter;
use common_error::ext::BoxedError;
use common_telemetry::info;
use datatypes::data_type::ConcreteDataType as CDT;
@@ -25,6 +26,7 @@ use query::parser::QueryLanguageParser;
use query::plan::LogicalPlan;
use query::query_engine::DefaultSerializer;
use query::QueryEngine;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
@@ -43,7 +45,6 @@ use crate::adapter::FlownodeContext;
use crate::expr::GlobalId;
use crate::plan::TypedPlan;
use crate::repr::RelationType;
/// a simple macro to generate a not implemented error
macro_rules! not_impl_err {
($($arg:tt)*) => {
@@ -67,17 +68,26 @@ mod expr;
mod literal;
mod plan;
pub(crate) use expr::from_scalar_fn_to_df_fn_impl;
/// In Substrait, a function can be define by an u32 anchor, and the anchor can be mapped to a name
///
/// So in substrait plan, a ref to a function can be a single u32 anchor instead of a full name in string
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FunctionExtensions {
anchor_to_name: HashMap<u32, String>,
anchor_to_name: BTreeMap<u32, String>,
}
impl FunctionExtensions {
pub fn from_iter(inner: impl IntoIterator<Item = (u32, impl ToString)>) -> Self {
Self {
anchor_to_name: inner.into_iter().map(|(k, s)| (k, s.to_string())).collect(),
}
}
/// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration
pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
let mut anchor_to_name = HashMap::new();
let mut anchor_to_name = BTreeMap::new();
for e in extensions {
match &e.mapping_type {
Some(ext) => match ext {
@@ -96,6 +106,10 @@ impl FunctionExtensions {
pub fn get(&self, anchor: &u32) -> Option<&String> {
self.anchor_to_name.get(anchor)
}
pub fn inner_ref(&self) -> HashMap<u32, &String> {
self.anchor_to_name.iter().map(|(k, v)| (*k, v)).collect()
}
}
/// To reuse existing code for parse sql, the sql is first parsed into a datafusion logical plan,

View File

@@ -53,14 +53,14 @@ use crate::expr::{
TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
};
use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
use crate::repr::{self, ColumnType, RelationType};
use crate::repr::{self, ColumnType, RelationDesc, RelationType};
use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
impl TypedExpr {
fn from_substrait_agg_grouping(
ctx: &mut FlownodeContext,
groupings: &[Grouping],
typ: &RelationType,
typ: &RelationDesc,
extensions: &FunctionExtensions,
) -> Result<Vec<TypedExpr>, Error> {
let _ = ctx;
@@ -89,7 +89,7 @@ impl AggregateExpr {
fn from_substrait_agg_measures(
ctx: &mut FlownodeContext,
measures: &[Measure],
typ: &RelationType,
typ: &RelationDesc,
extensions: &FunctionExtensions,
) -> Result<(Vec<AggregateExpr>, MapFilterProject), Error> {
let _ = ctx;
@@ -143,7 +143,7 @@ impl AggregateExpr {
/// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)`
pub fn from_substrait_agg_func(
f: &proto::AggregateFunction,
input_schema: &RelationType,
input_schema: &RelationDesc,
extensions: &FunctionExtensions,
filter: &Option<TypedExpr>,
order_by: &Option<Vec<TypedExpr>>,
@@ -320,7 +320,7 @@ impl TypedPlan {
let group_exprs = TypedExpr::from_substrait_agg_grouping(
ctx,
&agg.groupings,
&input.schema.typ,
&input.schema,
extensions,
)?;
@@ -332,7 +332,7 @@ impl TypedPlan {
let (mut aggr_exprs, post_mfp) = AggregateExpr::from_substrait_agg_measures(
ctx,
&agg.measures,
&input.schema.typ,
&input.schema,
extensions,
)?;

View File

@@ -14,6 +14,9 @@
#![warn(unused_imports)]
use std::sync::Arc;
use datafusion_physical_expr::PhysicalExpr;
use datatypes::data_type::ConcreteDataType as CDT;
use itertools::Itertools;
use snafu::{OptionExt, ResultExt};
@@ -24,15 +27,17 @@ use substrait_proto::proto::function_argument::ArgType;
use substrait_proto::proto::Expression;
use crate::adapter::error::{
DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu,
DatafusionSnafu, DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu,
PlanSnafu,
};
use crate::expr::{
BinaryFunc, ScalarExpr, TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc,
UnmaterializableFunc, VariadicFunc,
};
use crate::repr::{ColumnType, RelationType};
use crate::repr::{ColumnType, RelationDesc};
use crate::transform::literal::{from_substrait_literal, from_substrait_type};
use crate::transform::{substrait_proto, FunctionExtensions};
// TODO: found proper place for this
// TODO(discord9): found proper place for this
/// ref to `arrow_schema::datatype` for type name
fn typename_to_cdt(name: &str) -> CDT {
match name {
@@ -54,11 +59,64 @@ fn typename_to_cdt(name: &str) -> CDT {
}
}
/// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`]
pub(crate) fn from_scalar_fn_to_df_fn_impl(
f: &ScalarFunction,
input_schema: &RelationDesc,
extensions: &FunctionExtensions,
) -> Result<Arc<dyn PhysicalExpr>, Error> {
let e = Expression {
rex_type: Some(RexType::ScalarFunction(f.clone())),
};
let schema = input_schema.to_df_schema()?;
let df_expr = futures::executor::block_on(async {
// TODO(discord9): consider coloring everything async....
substrait::df_logical_plan::consumer::from_substrait_rex(
&datafusion::prelude::SessionContext::new(),
&e,
&schema,
&extensions.inner_ref(),
)
.await
});
let expr = df_expr.map_err(|err| {
DatafusionSnafu {
raw: err,
context: "Failed to convert substrait scalar function to datafusion scalar function",
}
.build()
})?;
let phy_expr =
datafusion::physical_expr::create_physical_expr(&expr, &schema, &Default::default())
.map_err(|err| {
DatafusionSnafu {
raw: err,
context: "Failed to create physical expression from logical expression",
}
.build()
})?;
Ok(phy_expr)
}
impl TypedExpr {
pub fn from_substrait_to_datafusion_scalar_func(
f: &ScalarFunction,
input_schema: &RelationDesc,
extensions: &FunctionExtensions,
) -> Result<TypedExpr, Error> {
let phy_expr = from_scalar_fn_to_df_fn_impl(f, input_schema, extensions)?;
let raw_fn = RawDfScalarFn::from_proto(f, input_schema.clone(), extensions.clone())?;
let expr = DfScalarFunction::new(raw_fn, phy_expr)?;
let expr = ScalarExpr::CallDf { df_scalar_fn: expr };
// df already know it's own schema, so not providing here
let ret_type = expr.typ(&[])?;
Ok(TypedExpr::new(expr, ret_type))
}
/// Convert ScalarFunction into Flow's ScalarExpr
pub fn from_substrait_scalar_func(
f: &ScalarFunction,
input_schema: &RelationType,
input_schema: &RelationDesc,
extensions: &FunctionExtensions,
) -> Result<TypedExpr, Error> {
let fn_name =
@@ -182,7 +240,12 @@ impl TypedExpr {
ret_type,
))
} else {
not_impl_err!("Unsupported function {fn_name} with {arg_len} arguments")
let try_as_df = Self::from_substrait_to_datafusion_scalar_func(
f,
input_schema,
extensions,
)?;
Ok(try_as_df)
}
}
}
@@ -191,7 +254,7 @@ impl TypedExpr {
/// Convert IfThen into Flow's ScalarExpr
pub fn from_substrait_ifthen_rex(
if_then: &IfThen,
input_schema: &RelationType,
input_schema: &RelationDesc,
extensions: &FunctionExtensions,
) -> Result<TypedExpr, Error> {
let ifs: Vec<_> = if_then
@@ -246,7 +309,7 @@ impl TypedExpr {
/// Convert Substrait Rex into Flow's ScalarExpr
pub fn from_substrait_rex(
e: &Expression,
input_schema: &RelationType,
input_schema: &RelationDesc,
extensions: &FunctionExtensions,
) -> Result<TypedExpr, Error> {
match &e.rex_type {
@@ -277,7 +340,7 @@ impl TypedExpr {
}
None => {
let column = x.field as usize;
let column_type = input_schema.column_types[column].clone();
let column_type = input_schema.typ().column_types[column].clone();
Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
}
},
@@ -322,8 +385,6 @@ impl TypedExpr {
#[cfg(test)]
mod test {
use std::collections::HashMap;
use common_time::{DateTime, Interval};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::Value;
@@ -584,10 +645,9 @@ mod test {
output_type: None,
..Default::default()
};
let input_schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]);
let extensions = FunctionExtensions {
anchor_to_name: HashMap::from([(0, "is_null".to_string())]),
};
let input_schema =
RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed();
let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]);
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
assert_eq!(
@@ -611,10 +671,9 @@ mod test {
let input_schema = RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), false),
ColumnType::new(CDT::uint32_datatype(), false),
]);
let extensions = FunctionExtensions {
anchor_to_name: HashMap::from([(0, "add".to_string())]),
};
])
.into_unnamed();
let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]);
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
assert_eq!(
@@ -639,10 +698,9 @@ mod test {
let input_schema = RelationType::new(vec![
ColumnType::new(CDT::timestamp_nanosecond_datatype(), false),
ColumnType::new(CDT::string_datatype(), false),
]);
let extensions = FunctionExtensions {
anchor_to_name: HashMap::from([(0, "tumble".to_string())]),
};
])
.into_unnamed();
let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]);
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
assert_eq!(
@@ -668,10 +726,9 @@ mod test {
let input_schema = RelationType::new(vec![
ColumnType::new(CDT::timestamp_nanosecond_datatype(), false),
ColumnType::new(CDT::string_datatype(), false),
]);
let extensions = FunctionExtensions {
anchor_to_name: HashMap::from([(0, "tumble".to_string())]),
};
])
.into_unnamed();
let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]);
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
assert_eq!(

View File

@@ -64,7 +64,7 @@ impl TypedPlan {
}
/// Convert Substrait Rel into Flow's TypedPlan
/// TODO: SELECT DISTINCT(does it get compile with something else?)
/// TODO(discord9): SELECT DISTINCT(does it get compile with something else?)
pub fn from_substrait_rel(
ctx: &mut FlownodeContext,
rel: &Rel,
@@ -80,7 +80,7 @@ impl TypedPlan {
let mut exprs: Vec<TypedExpr> = vec![];
for e in &p.expressions {
let expr = TypedExpr::from_substrait_rex(e, &input.schema.typ, extensions)?;
let expr = TypedExpr::from_substrait_rex(e, &input.schema, extensions)?;
exprs.push(expr);
}
let is_literal = exprs.iter().all(|expr| expr.expr.is_literal());
@@ -133,7 +133,7 @@ impl TypedPlan {
};
let expr = if let Some(condition) = filter.condition.as_ref() {
TypedExpr::from_substrait_rex(condition, &input.schema.typ, extensions)?
TypedExpr::from_substrait_rex(condition, &input.schema, extensions)?
} else {
return not_impl_err!("Filter without an condition is not valid");
};
@@ -213,7 +213,7 @@ fn rewrite_projection_after_reduce(
reduce_output_type: &RelationDesc,
proj_exprs: &mut Vec<TypedExpr>,
) -> Result<(), Error> {
// TODO: get keys correctly
// TODO(discord9): get keys correctly
let key_exprs = key_val_plan
.key_plan
.projection

View File

@@ -177,7 +177,7 @@ pub struct Arrangement {
///
/// Since updates typically occur as a delete followed by an insert, a small vector of size 2 is used to store updates for efficiency.
///
/// TODO: Consider balancing the batch size?
/// TODO(discord9): Consider balancing the batch size?
spine: Spine,
/// Indicates whether the arrangement maintains a complete history of updates.