diff --git a/Cargo.lock b/Cargo.lock index dad0265790..7219c09f89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3809,7 +3809,9 @@ name = "flow" version = "0.8.2" dependencies = [ "api", + "arrow-schema", "async-trait", + "bytes", "catalog", "common-base", "common-catalog", @@ -3824,8 +3826,10 @@ dependencies = [ "common-runtime", "common-telemetry", "common-time", + "datafusion 38.0.0", "datafusion-common 38.0.0", "datafusion-expr 38.0.0", + "datafusion-physical-expr 38.0.0", "datatypes", "enum-as-inner", "enum_dispatch", diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index 1f1bd1562f..8283f0595f 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -9,29 +9,33 @@ workspace = true [dependencies] api.workspace = true +arrow-schema.workspace = true +async-trait.workspace = true +bytes.workspace = true catalog.workspace = true common-base.workspace = true common-decimal.workspace = true common-error.workspace = true common-frontend.workspace = true -common-macro.workspace = true -common-runtime.workspace = true -common-telemetry.workspace = true -common-time.workspace = true -datafusion-common.workspace = true -datafusion-expr.workspace = true -datatypes.workspace = true -enum_dispatch = "0.3" -futures = "0.3" -# This fork is simply for keeping our dependency in our org, and pin the version -# it is the same with upstream repo -async-trait.workspace = true common-function.workspace = true +common-macro.workspace = true common-meta.workspace = true common-query.workspace = true common-recordbatch.workspace = true +common-runtime.workspace = true +common-telemetry.workspace = true +common-time.workspace = true +datafusion.workspace = true +datafusion-common.workspace = true +datafusion-expr.workspace = true +datafusion-physical-expr.workspace = true +datatypes.workspace = true enum-as-inner = "0.6.0" +enum_dispatch = "0.3" +futures = "0.3" greptime-proto.workspace = true +# This fork of hydroflow is simply for keeping our dependency in our org, and pin the version +# otherwise it is the same with upstream repo hydroflow = { git = "https://github.com/GreptimeTeam/hydroflow.git", branch = "main" } itertools.workspace = true minstant = "0.1.7" diff --git a/src/flow/src/adapter.rs b/src/flow/src/adapter.rs index 80c3a9ff78..fd07ff1dc2 100644 --- a/src/flow/src/adapter.rs +++ b/src/flow/src/adapter.rs @@ -64,12 +64,12 @@ mod table_source; use error::Error; -// TODO: replace this with `GREPTIME_TIMESTAMP` before v0.9 +// TODO(discord9): replace this with `GREPTIME_TIMESTAMP` before v0.9 pub const AUTO_CREATED_PLACEHOLDER_TS_COL: &str = "__ts_placeholder"; pub const UPDATE_AT_TS_COL: &str = "update_at"; -// TODO: refactor common types for flow to a separate module +// TODO(discord9): refactor common types for flow to a separate module /// FlowId is a unique identifier for a flow task pub type FlowId = u64; pub type TableName = [String; 3]; diff --git a/src/flow/src/adapter/node_context.rs b/src/flow/src/adapter/node_context.rs index dcbfb65719..40c5169f5e 100644 --- a/src/flow/src/adapter/node_context.rs +++ b/src/flow/src/adapter/node_context.rs @@ -79,7 +79,6 @@ impl Default for SourceSender { } } -// TODO: make all send operation immut impl SourceSender { pub fn get_receiver(&self) -> broadcast::Receiver { self.sender.subscribe() diff --git a/src/flow/src/expr.rs b/src/flow/src/expr.rs index 7fb2ba7f29..aefc4db3be 100644 --- a/src/flow/src/expr.rs +++ b/src/flow/src/expr.rs @@ -27,4 +27,4 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc} pub(crate) use id::{GlobalId, Id, LocalId}; pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan}; pub(crate) use relation::{AggregateExpr, AggregateFunc}; -pub(crate) use scalar::{ScalarExpr, TypedExpr}; +pub(crate) use scalar::{DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr}; diff --git a/src/flow/src/expr/error.rs b/src/flow/src/expr/error.rs index 09ad758056..ff1765df49 100644 --- a/src/flow/src/expr/error.rs +++ b/src/flow/src/expr/error.rs @@ -16,12 +16,15 @@ use std::any::Any; +use arrow_schema::ArrowError; +use common_error::ext::BoxedError; use common_macro::stack_trace_debug; use common_telemetry::common_error::ext::ErrorExt; use common_telemetry::common_error::status_code::StatusCode; +use datafusion_common::DataFusionError; use datatypes::data_type::ConcreteDataType; use serde::{Deserialize, Serialize}; -use snafu::{Location, Snafu}; +use snafu::{Location, ResultExt, Snafu}; fn is_send_sync() { fn check() {} @@ -107,4 +110,27 @@ pub enum EvalError { #[snafu(implicit)] location: Location, }, + + #[snafu(display("Arrow error: {raw:?}, context: {context}"))] + Arrow { + raw: ArrowError, + context: String, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("DataFusion error: {raw:?}, context: {context}"))] + Datafusion { + raw: DataFusionError, + context: String, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("External error"))] + External { + #[snafu(implicit)] + location: Location, + source: BoxedError, + }, } diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs index 31131a2758..2109356ad6 100644 --- a/src/flow/src/expr/func.rs +++ b/src/flow/src/expr/func.rs @@ -814,7 +814,7 @@ impl VariadicFunc { name: &str, arg_types: &[Option], ) -> Result { - // TODO: future variadic funcs to be added might need to check arg_types + // TODO(discord9): future variadic funcs to be added might need to check arg_types let _ = arg_types; match name { "and" => Ok(Self::And), diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 984d6f1a44..8103089e67 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -15,19 +15,33 @@ //! Scalar expressions. use std::collections::{BTreeMap, BTreeSet}; +use std::sync::{Arc, Mutex}; +use bytes::BytesMut; +use common_error::ext::BoxedError; +use common_recordbatch::DfRecordBatch; +use datafusion_physical_expr::PhysicalExpr; +use datatypes::arrow_array; +use datatypes::data_type::DataType; use datatypes::prelude::ConcreteDataType; use datatypes::value::Value; +use prost::Message; use serde::{Deserialize, Serialize}; -use snafu::ensure; +use snafu::{ensure, ResultExt}; +use substrait::error::{DecodeRelSnafu, EncodeRelSnafu}; +use substrait::substrait_proto_df::proto::expression::{RexType, ScalarFunction}; +use substrait::substrait_proto_df::proto::Expression; use crate::adapter::error::{ - Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu, + DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu, +}; +use crate::expr::error::{ + ArrowSnafu, DatafusionSnafu as EvalDatafusionSnafu, EvalError, ExternalSnafu, + InvalidArgumentSnafu, OptimizeSnafu, }; -use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu}; use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; -use crate::repr::{ColumnType, RelationType}; - +use crate::repr::{ColumnType, RelationDesc, RelationType}; +use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions}; /// A scalar expression with a known type. #[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] pub struct TypedExpr { @@ -47,6 +61,8 @@ impl TypedExpr { /// expand multi-value expression to multiple expressions with new indices /// /// Currently it just mean expand `TumbleWindow` to `TumbleWindowFloor` and `TumbleWindowCeiling` + /// + /// TODO(discord9): test if nested reduce combine with df scalar function would cause problem pub fn expand_multi_value( input_typ: &RelationType, exprs: &[TypedExpr], @@ -138,6 +154,10 @@ pub enum ScalarExpr { func: VariadicFunc, exprs: Vec, }, + CallDf { + // TODO(discord9): support shuffle + df_scalar_fn: DfScalarFunction, + }, /// Conditionally evaluated expressions. /// /// It is important that `then` and `els` only be evaluated if @@ -151,6 +171,161 @@ pub enum ScalarExpr { }, } +/// A way to represent a scalar function that is implemented in Datafusion +#[derive(Debug, Clone)] +pub struct DfScalarFunction { + raw_fn: RawDfScalarFn, + // TODO(discord9): directly from datafusion expr + fn_impl: Arc, + df_schema: Arc, +} + +impl DfScalarFunction { + pub fn new(raw_fn: RawDfScalarFn, fn_impl: Arc) -> Result { + Ok(Self { + df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?), + raw_fn, + fn_impl, + }) + } + + // TODO(discord9): add RecordBatch support + pub fn eval(&self, values: &[Value]) -> Result { + if values.is_empty() { + return InvalidArgumentSnafu { + reason: "values is empty".to_string(), + } + .fail(); + } + // TODO(discord9): make cols all array length of one + let mut cols = vec![]; + for (idx, typ) in self + .raw_fn + .input_schema + .typ() + .column_types + .iter() + .enumerate() + { + let typ = typ.scalar_type(); + let mut array = typ.create_mutable_vector(1); + array.push_value_ref(values[idx].as_value_ref()); + cols.push(array.to_vector().to_arrow_array()); + } + let schema = self.df_schema.inner().clone(); + let rb = DfRecordBatch::try_new(schema, cols).map_err(|err| { + ArrowSnafu { + raw: err, + context: + "Failed to create RecordBatch from values when eval datafusion scalar function", + } + .build() + })?; + let res = self.fn_impl.evaluate(&rb).map_err(|err| { + EvalDatafusionSnafu { + raw: err, + context: "Failed to evaluate datafusion scalar function", + } + .build() + })?; + let res = common_query::columnar_value::ColumnarValue::try_from(&res) + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + let res_vec = res + .try_into_vector(1) + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + let res_val = res_vec + .try_get(0) + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + Ok(res_val) + } +} + +// simply serialize the raw_fn instead of derive to avoid complex deserialize of struct +impl Serialize for DfScalarFunction { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.raw_fn.serialize(serializer) + } +} + +impl<'de> serde::de::Deserialize<'de> for DfScalarFunction { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let raw_fn = RawDfScalarFn::deserialize(deserializer)?; + let fn_impl = raw_fn.get_fn_impl().map_err(serde::de::Error::custom)?; + DfScalarFunction::new(raw_fn, fn_impl).map_err(serde::de::Error::custom) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct RawDfScalarFn { + f: bytes::BytesMut, + input_schema: RelationDesc, + extensions: FunctionExtensions, +} + +impl RawDfScalarFn { + pub fn from_proto( + f: &substrait::substrait_proto_df::proto::expression::ScalarFunction, + input_schema: RelationDesc, + extensions: FunctionExtensions, + ) -> Result { + let mut buf = BytesMut::new(); + f.encode(&mut buf) + .context(EncodeRelSnafu) + .map_err(BoxedError::new) + .context(crate::adapter::error::ExternalSnafu)?; + Ok(Self { + f: buf, + input_schema, + extensions, + }) + } + fn get_fn_impl(&self) -> Result, Error> { + let f = ScalarFunction::decode(&mut self.f.as_ref()) + .context(DecodeRelSnafu) + .map_err(BoxedError::new) + .context(crate::adapter::error::ExternalSnafu)?; + + let input_schema = &self.input_schema; + let extensions = &self.extensions; + + from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions) + } +} + +impl std::cmp::PartialEq for DfScalarFunction { + fn eq(&self, other: &Self) -> bool { + self.raw_fn.eq(&other.raw_fn) + } +} + +// can't derive Eq because of Arc not eq, so implement it manually +impl std::cmp::Eq for DfScalarFunction {} + +impl std::cmp::PartialOrd for DfScalarFunction { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl std::cmp::Ord for DfScalarFunction { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.raw_fn.cmp(&other.raw_fn) + } +} +impl std::hash::Hash for DfScalarFunction { + fn hash(&self, state: &mut H) { + self.raw_fn.hash(state); + } +} + impl ScalarExpr { pub fn with_type(self, typ: ColumnType) -> TypedExpr { TypedExpr::new(self, typ) @@ -179,6 +354,23 @@ impl ScalarExpr { Ok(ColumnType::new_nullable(func.signature().output)) } ScalarExpr::If { then, .. } => then.typ(context), + ScalarExpr::CallDf { df_scalar_fn } => { + let arrow_typ = df_scalar_fn + .fn_impl + // TODO(discord9): get scheme from args instead? + .data_type(df_scalar_fn.df_schema.as_arrow()) + .map_err(|err| { + DatafusionSnafu { + raw: err, + context: "Failed to get data type from datafusion scalar function", + } + .build() + })?; + let typ = ConcreteDataType::try_from(&arrow_typ) + .map_err(BoxedError::new) + .context(crate::adapter::error::ExternalSnafu)?; + Ok(ColumnType::new_nullable(typ)) + } } } } @@ -253,6 +445,7 @@ impl ScalarExpr { } .fail(), }, + ScalarExpr::CallDf { df_scalar_fn } => df_scalar_fn.eval(values), } } @@ -421,6 +614,7 @@ impl ScalarExpr { f(then)?; f(els) } + _ => Ok(()), } } @@ -456,6 +650,7 @@ impl ScalarExpr { f(then)?; f(els) } + _ => Ok(()), } } } @@ -497,7 +692,7 @@ impl ScalarExpr { return unsupported_err("Not a binary expression"); }; - // TODO: support simple transform like `now() + a < b` to `now() < b - a` + // TODO(discord9): support simple transform like `now() + a < b` to `now() < b - a` let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now); let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now); @@ -531,6 +726,15 @@ impl ScalarExpr { #[cfg(test)] mod test { use datatypes::arrow::array::Scalar; + use query::parser::QueryLanguageParser; + use query::QueryEngine; + use session::context::QueryContext; + use substrait::extension_serializer; + use substrait::substrait_proto_df::proto::expression::literal::LiteralType; + use substrait::substrait_proto_df::proto::expression::Literal; + use substrait::substrait_proto_df::proto::function_argument::ArgType; + use substrait::substrait_proto_df::proto::r#type::Kind; + use substrait::substrait_proto_df::proto::{r#type, FunctionArgument, Type}; use super::*; #[test] @@ -622,4 +826,37 @@ mod test { let res = expr.permute_map(&permute_map); assert!(matches!(res, Err(Error::InvalidQuery { .. }))); } + + #[tokio::test] + async fn test_df_scalar_function() { + let raw_scalar_func = ScalarFunction { + function_reference: 0, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::I64(-1)), + })), + })), + }], + output_type: None, + ..Default::default() + }; + let input_schema = RelationDesc::try_new( + RelationType::new(vec![ColumnType::new_nullable( + ConcreteDataType::null_datatype(), + )]), + vec!["null_column".to_string()], + ) + .unwrap(); + let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]); + let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap(); + let fn_impl = raw_fn.get_fn_impl().unwrap(); + let df_func = DfScalarFunction::new(raw_fn, fn_impl).unwrap(); + let as_str = serde_json::to_string(&df_func).unwrap(); + let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap(); + assert_eq!(df_func, from_str); + assert_eq!(df_func.eval(&[Value::Null]).unwrap(), Value::Int64(1)); + } } diff --git a/src/flow/src/repr/relation.rs b/src/flow/src/repr/relation.rs index 8d7fdd9a33..ae5c6b46ff 100644 --- a/src/flow/src/repr/relation.rs +++ b/src/flow/src/repr/relation.rs @@ -14,12 +14,14 @@ use std::collections::{BTreeMap, HashMap}; +use datafusion_common::DFSchema; +use datatypes::data_type::DataType; use datatypes::prelude::ConcreteDataType; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use snafu::{ensure, OptionExt}; +use snafu::{ensure, OptionExt, ResultExt}; -use crate::adapter::error::{InvalidQuerySnafu, Result, UnexpectedSnafu}; +use crate::adapter::error::{DatafusionSnafu, InvalidQuerySnafu, Result, UnexpectedSnafu}; use crate::expr::{MapFilterProject, SafeMfpPlan, ScalarExpr}; /// a set of column indices that are "keys" for the collection. @@ -338,9 +340,31 @@ pub struct RelationDesc { } impl RelationDesc { + pub fn to_df_schema(&self) -> Result { + let fields: Vec<_> = self + .iter() + .enumerate() + .map(|(i, (name, typ))| { + let name = name.clone().unwrap_or(format!("Col_{i}")); + let nullable = typ.nullable; + let data_type = typ.scalar_type.clone().as_arrow_type(); + arrow_schema::Field::new(name, data_type, nullable) + }) + .collect(); + let arrow_schema = arrow_schema::Schema::new(fields); + + DFSchema::try_from(arrow_schema.clone()).map_err(|err| { + DatafusionSnafu { + raw: err, + context: format!("Error when converting to DFSchema: {:?}", arrow_schema), + } + .build() + }) + } + /// apply mfp, and also project col names for the projected columns pub fn apply_mfp(&self, mfp: &SafeMfpPlan) -> Result { - // TODO: find a way to deduce name at best effect + // TODO(discord9): find a way to deduce name at best effect let names = { let mfp = &mfp.mfp; let mut names = self.names.clone(); diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index 39166cce13..35d811a037 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -13,9 +13,10 @@ // limitations under the License. //! Transform Substrait into execution plan -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; +use bytes::buf::IntoIter; use common_error::ext::BoxedError; use common_telemetry::info; use datatypes::data_type::ConcreteDataType as CDT; @@ -25,6 +26,7 @@ use query::parser::QueryLanguageParser; use query::plan::LogicalPlan; use query::query_engine::DefaultSerializer; use query::QueryEngine; +use serde::{Deserialize, Serialize}; use session::context::QueryContext; use snafu::{OptionExt, ResultExt}; /// note here we are using the `substrait_proto_df` crate from the `substrait` module and @@ -43,7 +45,6 @@ use crate::adapter::FlownodeContext; use crate::expr::GlobalId; use crate::plan::TypedPlan; use crate::repr::RelationType; - /// a simple macro to generate a not implemented error macro_rules! not_impl_err { ($($arg:tt)*) => { @@ -67,17 +68,26 @@ mod expr; mod literal; mod plan; +pub(crate) use expr::from_scalar_fn_to_df_fn_impl; + /// In Substrait, a function can be define by an u32 anchor, and the anchor can be mapped to a name /// /// So in substrait plan, a ref to a function can be a single u32 anchor instead of a full name in string +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct FunctionExtensions { - anchor_to_name: HashMap, + anchor_to_name: BTreeMap, } impl FunctionExtensions { + pub fn from_iter(inner: impl IntoIterator) -> Self { + Self { + anchor_to_name: inner.into_iter().map(|(k, s)| (k, s.to_string())).collect(), + } + } + /// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result { - let mut anchor_to_name = HashMap::new(); + let mut anchor_to_name = BTreeMap::new(); for e in extensions { match &e.mapping_type { Some(ext) => match ext { @@ -96,6 +106,10 @@ impl FunctionExtensions { pub fn get(&self, anchor: &u32) -> Option<&String> { self.anchor_to_name.get(anchor) } + + pub fn inner_ref(&self) -> HashMap { + self.anchor_to_name.iter().map(|(k, v)| (*k, v)).collect() + } } /// To reuse existing code for parse sql, the sql is first parsed into a datafusion logical plan, diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index 688f616ebf..19a8ba2dfc 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -53,14 +53,14 @@ use crate::expr::{ TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc, }; use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan}; -use crate::repr::{self, ColumnType, RelationType}; +use crate::repr::{self, ColumnType, RelationDesc, RelationType}; use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions}; impl TypedExpr { fn from_substrait_agg_grouping( ctx: &mut FlownodeContext, groupings: &[Grouping], - typ: &RelationType, + typ: &RelationDesc, extensions: &FunctionExtensions, ) -> Result, Error> { let _ = ctx; @@ -89,7 +89,7 @@ impl AggregateExpr { fn from_substrait_agg_measures( ctx: &mut FlownodeContext, measures: &[Measure], - typ: &RelationType, + typ: &RelationDesc, extensions: &FunctionExtensions, ) -> Result<(Vec, MapFilterProject), Error> { let _ = ctx; @@ -143,7 +143,7 @@ impl AggregateExpr { /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)` pub fn from_substrait_agg_func( f: &proto::AggregateFunction, - input_schema: &RelationType, + input_schema: &RelationDesc, extensions: &FunctionExtensions, filter: &Option, order_by: &Option>, @@ -320,7 +320,7 @@ impl TypedPlan { let group_exprs = TypedExpr::from_substrait_agg_grouping( ctx, &agg.groupings, - &input.schema.typ, + &input.schema, extensions, )?; @@ -332,7 +332,7 @@ impl TypedPlan { let (mut aggr_exprs, post_mfp) = AggregateExpr::from_substrait_agg_measures( ctx, &agg.measures, - &input.schema.typ, + &input.schema, extensions, )?; diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs index eb3f9bafc3..5434ea237b 100644 --- a/src/flow/src/transform/expr.rs +++ b/src/flow/src/transform/expr.rs @@ -14,6 +14,9 @@ #![warn(unused_imports)] +use std::sync::Arc; + +use datafusion_physical_expr::PhysicalExpr; use datatypes::data_type::ConcreteDataType as CDT; use itertools::Itertools; use snafu::{OptionExt, ResultExt}; @@ -24,15 +27,17 @@ use substrait_proto::proto::function_argument::ArgType; use substrait_proto::proto::Expression; use crate::adapter::error::{ - DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu, + DatafusionSnafu, DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, + PlanSnafu, }; use crate::expr::{ - BinaryFunc, ScalarExpr, TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc, + BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc, + UnmaterializableFunc, VariadicFunc, }; -use crate::repr::{ColumnType, RelationType}; +use crate::repr::{ColumnType, RelationDesc}; use crate::transform::literal::{from_substrait_literal, from_substrait_type}; use crate::transform::{substrait_proto, FunctionExtensions}; -// TODO: found proper place for this +// TODO(discord9): found proper place for this /// ref to `arrow_schema::datatype` for type name fn typename_to_cdt(name: &str) -> CDT { match name { @@ -54,11 +59,64 @@ fn typename_to_cdt(name: &str) -> CDT { } } +/// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`] +pub(crate) fn from_scalar_fn_to_df_fn_impl( + f: &ScalarFunction, + input_schema: &RelationDesc, + extensions: &FunctionExtensions, +) -> Result, Error> { + let e = Expression { + rex_type: Some(RexType::ScalarFunction(f.clone())), + }; + let schema = input_schema.to_df_schema()?; + + let df_expr = futures::executor::block_on(async { + // TODO(discord9): consider coloring everything async.... + substrait::df_logical_plan::consumer::from_substrait_rex( + &datafusion::prelude::SessionContext::new(), + &e, + &schema, + &extensions.inner_ref(), + ) + .await + }); + let expr = df_expr.map_err(|err| { + DatafusionSnafu { + raw: err, + context: "Failed to convert substrait scalar function to datafusion scalar function", + } + .build() + })?; + let phy_expr = + datafusion::physical_expr::create_physical_expr(&expr, &schema, &Default::default()) + .map_err(|err| { + DatafusionSnafu { + raw: err, + context: "Failed to create physical expression from logical expression", + } + .build() + })?; + Ok(phy_expr) +} + impl TypedExpr { + pub fn from_substrait_to_datafusion_scalar_func( + f: &ScalarFunction, + input_schema: &RelationDesc, + extensions: &FunctionExtensions, + ) -> Result { + let phy_expr = from_scalar_fn_to_df_fn_impl(f, input_schema, extensions)?; + let raw_fn = RawDfScalarFn::from_proto(f, input_schema.clone(), extensions.clone())?; + let expr = DfScalarFunction::new(raw_fn, phy_expr)?; + let expr = ScalarExpr::CallDf { df_scalar_fn: expr }; + // df already know it's own schema, so not providing here + let ret_type = expr.typ(&[])?; + Ok(TypedExpr::new(expr, ret_type)) + } /// Convert ScalarFunction into Flow's ScalarExpr pub fn from_substrait_scalar_func( f: &ScalarFunction, - input_schema: &RelationType, + input_schema: &RelationDesc, extensions: &FunctionExtensions, ) -> Result { let fn_name = @@ -182,7 +240,12 @@ impl TypedExpr { ret_type, )) } else { - not_impl_err!("Unsupported function {fn_name} with {arg_len} arguments") + let try_as_df = Self::from_substrait_to_datafusion_scalar_func( + f, + input_schema, + extensions, + )?; + Ok(try_as_df) } } } @@ -191,7 +254,7 @@ impl TypedExpr { /// Convert IfThen into Flow's ScalarExpr pub fn from_substrait_ifthen_rex( if_then: &IfThen, - input_schema: &RelationType, + input_schema: &RelationDesc, extensions: &FunctionExtensions, ) -> Result { let ifs: Vec<_> = if_then @@ -246,7 +309,7 @@ impl TypedExpr { /// Convert Substrait Rex into Flow's ScalarExpr pub fn from_substrait_rex( e: &Expression, - input_schema: &RelationType, + input_schema: &RelationDesc, extensions: &FunctionExtensions, ) -> Result { match &e.rex_type { @@ -277,7 +340,7 @@ impl TypedExpr { } None => { let column = x.field as usize; - let column_type = input_schema.column_types[column].clone(); + let column_type = input_schema.typ().column_types[column].clone(); Ok(TypedExpr::new(ScalarExpr::Column(column), column_type)) } }, @@ -322,8 +385,6 @@ impl TypedExpr { #[cfg(test)] mod test { - use std::collections::HashMap; - use common_time::{DateTime, Interval}; use datatypes::prelude::ConcreteDataType; use datatypes::value::Value; @@ -584,10 +645,9 @@ mod test { output_type: None, ..Default::default() }; - let input_schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]); - let extensions = FunctionExtensions { - anchor_to_name: HashMap::from([(0, "is_null".to_string())]), - }; + let input_schema = + RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed(); + let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]); let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); assert_eq!( @@ -611,10 +671,9 @@ mod test { let input_schema = RelationType::new(vec![ ColumnType::new(CDT::uint32_datatype(), false), ColumnType::new(CDT::uint32_datatype(), false), - ]); - let extensions = FunctionExtensions { - anchor_to_name: HashMap::from([(0, "add".to_string())]), - }; + ]) + .into_unnamed(); + let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]); let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); assert_eq!( @@ -639,10 +698,9 @@ mod test { let input_schema = RelationType::new(vec![ ColumnType::new(CDT::timestamp_nanosecond_datatype(), false), ColumnType::new(CDT::string_datatype(), false), - ]); - let extensions = FunctionExtensions { - anchor_to_name: HashMap::from([(0, "tumble".to_string())]), - }; + ]) + .into_unnamed(); + let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); assert_eq!( @@ -668,10 +726,9 @@ mod test { let input_schema = RelationType::new(vec![ ColumnType::new(CDT::timestamp_nanosecond_datatype(), false), ColumnType::new(CDT::string_datatype(), false), - ]); - let extensions = FunctionExtensions { - anchor_to_name: HashMap::from([(0, "tumble".to_string())]), - }; + ]) + .into_unnamed(); + let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); assert_eq!( diff --git a/src/flow/src/transform/plan.rs b/src/flow/src/transform/plan.rs index 813266ee4d..339fe80586 100644 --- a/src/flow/src/transform/plan.rs +++ b/src/flow/src/transform/plan.rs @@ -64,7 +64,7 @@ impl TypedPlan { } /// Convert Substrait Rel into Flow's TypedPlan - /// TODO: SELECT DISTINCT(does it get compile with something else?) + /// TODO(discord9): SELECT DISTINCT(does it get compile with something else?) pub fn from_substrait_rel( ctx: &mut FlownodeContext, rel: &Rel, @@ -80,7 +80,7 @@ impl TypedPlan { let mut exprs: Vec = vec![]; for e in &p.expressions { - let expr = TypedExpr::from_substrait_rex(e, &input.schema.typ, extensions)?; + let expr = TypedExpr::from_substrait_rex(e, &input.schema, extensions)?; exprs.push(expr); } let is_literal = exprs.iter().all(|expr| expr.expr.is_literal()); @@ -133,7 +133,7 @@ impl TypedPlan { }; let expr = if let Some(condition) = filter.condition.as_ref() { - TypedExpr::from_substrait_rex(condition, &input.schema.typ, extensions)? + TypedExpr::from_substrait_rex(condition, &input.schema, extensions)? } else { return not_impl_err!("Filter without an condition is not valid"); }; @@ -213,7 +213,7 @@ fn rewrite_projection_after_reduce( reduce_output_type: &RelationDesc, proj_exprs: &mut Vec, ) -> Result<(), Error> { - // TODO: get keys correctly + // TODO(discord9): get keys correctly let key_exprs = key_val_plan .key_plan .projection diff --git a/src/flow/src/utils.rs b/src/flow/src/utils.rs index 30d48f0319..69c300ab8f 100644 --- a/src/flow/src/utils.rs +++ b/src/flow/src/utils.rs @@ -177,7 +177,7 @@ pub struct Arrangement { /// /// Since updates typically occur as a delete followed by an insert, a small vector of size 2 is used to store updates for efficiency. /// - /// TODO: Consider balancing the batch size? + /// TODO(discord9): Consider balancing the batch size? spine: Spine, /// Indicates whether the arrangement maintains a complete history of updates.