mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-18 05:50:41 +00:00
feat(flow): support datafusion scalar function (#4142)
* chore: call df function types * feat: RelationDesc to DfSchema * refactor: use RelationDesc instead of Type * chore: WIP get to phy expr * feat: custom deserialize * chore: fmt * refactor: renaming to DfScalarFunction * feat: eval df func(untested) * fix: had to spawn a thread for calling async * chore: per review advices * tests: test df scalar function
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -3809,7 +3809,9 @@ name = "flow"
|
||||
version = "0.8.2"
|
||||
dependencies = [
|
||||
"api",
|
||||
"arrow-schema",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"catalog",
|
||||
"common-base",
|
||||
"common-catalog",
|
||||
@@ -3824,8 +3826,10 @@ dependencies = [
|
||||
"common-runtime",
|
||||
"common-telemetry",
|
||||
"common-time",
|
||||
"datafusion 38.0.0",
|
||||
"datafusion-common 38.0.0",
|
||||
"datafusion-expr 38.0.0",
|
||||
"datafusion-physical-expr 38.0.0",
|
||||
"datatypes",
|
||||
"enum-as-inner",
|
||||
"enum_dispatch",
|
||||
|
||||
@@ -9,29 +9,33 @@ workspace = true
|
||||
|
||||
[dependencies]
|
||||
api.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
async-trait.workspace = true
|
||||
bytes.workspace = true
|
||||
catalog.workspace = true
|
||||
common-base.workspace = true
|
||||
common-decimal.workspace = true
|
||||
common-error.workspace = true
|
||||
common-frontend.workspace = true
|
||||
common-macro.workspace = true
|
||||
common-runtime.workspace = true
|
||||
common-telemetry.workspace = true
|
||||
common-time.workspace = true
|
||||
datafusion-common.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
datatypes.workspace = true
|
||||
enum_dispatch = "0.3"
|
||||
futures = "0.3"
|
||||
# This fork is simply for keeping our dependency in our org, and pin the version
|
||||
# it is the same with upstream repo
|
||||
async-trait.workspace = true
|
||||
common-function.workspace = true
|
||||
common-macro.workspace = true
|
||||
common-meta.workspace = true
|
||||
common-query.workspace = true
|
||||
common-recordbatch.workspace = true
|
||||
common-runtime.workspace = true
|
||||
common-telemetry.workspace = true
|
||||
common-time.workspace = true
|
||||
datafusion.workspace = true
|
||||
datafusion-common.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
datafusion-physical-expr.workspace = true
|
||||
datatypes.workspace = true
|
||||
enum-as-inner = "0.6.0"
|
||||
enum_dispatch = "0.3"
|
||||
futures = "0.3"
|
||||
greptime-proto.workspace = true
|
||||
# This fork of hydroflow is simply for keeping our dependency in our org, and pin the version
|
||||
# otherwise it is the same with upstream repo
|
||||
hydroflow = { git = "https://github.com/GreptimeTeam/hydroflow.git", branch = "main" }
|
||||
itertools.workspace = true
|
||||
minstant = "0.1.7"
|
||||
|
||||
@@ -64,12 +64,12 @@ mod table_source;
|
||||
|
||||
use error::Error;
|
||||
|
||||
// TODO: replace this with `GREPTIME_TIMESTAMP` before v0.9
|
||||
// TODO(discord9): replace this with `GREPTIME_TIMESTAMP` before v0.9
|
||||
pub const AUTO_CREATED_PLACEHOLDER_TS_COL: &str = "__ts_placeholder";
|
||||
|
||||
pub const UPDATE_AT_TS_COL: &str = "update_at";
|
||||
|
||||
// TODO: refactor common types for flow to a separate module
|
||||
// TODO(discord9): refactor common types for flow to a separate module
|
||||
/// FlowId is a unique identifier for a flow task
|
||||
pub type FlowId = u64;
|
||||
pub type TableName = [String; 3];
|
||||
|
||||
@@ -79,7 +79,6 @@ impl Default for SourceSender {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make all send operation immut
|
||||
impl SourceSender {
|
||||
pub fn get_receiver(&self) -> broadcast::Receiver<DiffRow> {
|
||||
self.sender.subscribe()
|
||||
|
||||
@@ -27,4 +27,4 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}
|
||||
pub(crate) use id::{GlobalId, Id, LocalId};
|
||||
pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan};
|
||||
pub(crate) use relation::{AggregateExpr, AggregateFunc};
|
||||
pub(crate) use scalar::{ScalarExpr, TypedExpr};
|
||||
pub(crate) use scalar::{DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr};
|
||||
|
||||
@@ -16,12 +16,15 @@
|
||||
|
||||
use std::any::Any;
|
||||
|
||||
use arrow_schema::ArrowError;
|
||||
use common_error::ext::BoxedError;
|
||||
use common_macro::stack_trace_debug;
|
||||
use common_telemetry::common_error::ext::ErrorExt;
|
||||
use common_telemetry::common_error::status_code::StatusCode;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datatypes::data_type::ConcreteDataType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::{Location, Snafu};
|
||||
use snafu::{Location, ResultExt, Snafu};
|
||||
|
||||
fn is_send_sync() {
|
||||
fn check<T: Send + Sync>() {}
|
||||
@@ -107,4 +110,27 @@ pub enum EvalError {
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Arrow error: {raw:?}, context: {context}"))]
|
||||
Arrow {
|
||||
raw: ArrowError,
|
||||
context: String,
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("DataFusion error: {raw:?}, context: {context}"))]
|
||||
Datafusion {
|
||||
raw: DataFusionError,
|
||||
context: String,
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("External error"))]
|
||||
External {
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
source: BoxedError,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -814,7 +814,7 @@ impl VariadicFunc {
|
||||
name: &str,
|
||||
arg_types: &[Option<ConcreteDataType>],
|
||||
) -> Result<Self, Error> {
|
||||
// TODO: future variadic funcs to be added might need to check arg_types
|
||||
// TODO(discord9): future variadic funcs to be added might need to check arg_types
|
||||
let _ = arg_types;
|
||||
match name {
|
||||
"and" => Ok(Self::And),
|
||||
|
||||
@@ -15,19 +15,33 @@
|
||||
//! Scalar expressions.
|
||||
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use common_error::ext::BoxedError;
|
||||
use common_recordbatch::DfRecordBatch;
|
||||
use datafusion_physical_expr::PhysicalExpr;
|
||||
use datatypes::arrow_array;
|
||||
use datatypes::data_type::DataType;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::value::Value;
|
||||
use prost::Message;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::ensure;
|
||||
use snafu::{ensure, ResultExt};
|
||||
use substrait::error::{DecodeRelSnafu, EncodeRelSnafu};
|
||||
use substrait::substrait_proto_df::proto::expression::{RexType, ScalarFunction};
|
||||
use substrait::substrait_proto_df::proto::Expression;
|
||||
|
||||
use crate::adapter::error::{
|
||||
Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
|
||||
DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
|
||||
};
|
||||
use crate::expr::error::{
|
||||
ArrowSnafu, DatafusionSnafu as EvalDatafusionSnafu, EvalError, ExternalSnafu,
|
||||
InvalidArgumentSnafu, OptimizeSnafu,
|
||||
};
|
||||
use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu};
|
||||
use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
|
||||
use crate::repr::{ColumnType, RelationType};
|
||||
|
||||
use crate::repr::{ColumnType, RelationDesc, RelationType};
|
||||
use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions};
|
||||
/// A scalar expression with a known type.
|
||||
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
|
||||
pub struct TypedExpr {
|
||||
@@ -47,6 +61,8 @@ impl TypedExpr {
|
||||
/// expand multi-value expression to multiple expressions with new indices
|
||||
///
|
||||
/// Currently it just means expanding `TumbleWindow` to `TumbleWindowFloor` and `TumbleWindowCeiling`
|
||||
///
|
||||
/// TODO(discord9): test if nested reduce combine with df scalar function would cause problem
|
||||
pub fn expand_multi_value(
|
||||
input_typ: &RelationType,
|
||||
exprs: &[TypedExpr],
|
||||
@@ -138,6 +154,10 @@ pub enum ScalarExpr {
|
||||
func: VariadicFunc,
|
||||
exprs: Vec<ScalarExpr>,
|
||||
},
|
||||
CallDf {
|
||||
// TODO(discord9): support shuffle
|
||||
df_scalar_fn: DfScalarFunction,
|
||||
},
|
||||
/// Conditionally evaluated expressions.
|
||||
///
|
||||
/// It is important that `then` and `els` only be evaluated if
|
||||
@@ -151,6 +171,161 @@ pub enum ScalarExpr {
|
||||
},
|
||||
}
|
||||
|
||||
/// A way to represent a scalar function that is implemented in Datafusion
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DfScalarFunction {
|
||||
raw_fn: RawDfScalarFn,
|
||||
// TODO(discord9): directly from datafusion expr
|
||||
fn_impl: Arc<dyn PhysicalExpr>,
|
||||
df_schema: Arc<datafusion_common::DFSchema>,
|
||||
}
|
||||
|
||||
impl DfScalarFunction {
|
||||
pub fn new(raw_fn: RawDfScalarFn, fn_impl: Arc<dyn PhysicalExpr>) -> Result<Self, Error> {
|
||||
Ok(Self {
|
||||
df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
|
||||
raw_fn,
|
||||
fn_impl,
|
||||
})
|
||||
}
|
||||
|
||||
// TODO(discord9): add RecordBatch support
|
||||
pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
|
||||
if values.is_empty() {
|
||||
return InvalidArgumentSnafu {
|
||||
reason: "values is empty".to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
// TODO(discord9): make cols all array length of one
|
||||
let mut cols = vec![];
|
||||
for (idx, typ) in self
|
||||
.raw_fn
|
||||
.input_schema
|
||||
.typ()
|
||||
.column_types
|
||||
.iter()
|
||||
.enumerate()
|
||||
{
|
||||
let typ = typ.scalar_type();
|
||||
let mut array = typ.create_mutable_vector(1);
|
||||
array.push_value_ref(values[idx].as_value_ref());
|
||||
cols.push(array.to_vector().to_arrow_array());
|
||||
}
|
||||
let schema = self.df_schema.inner().clone();
|
||||
let rb = DfRecordBatch::try_new(schema, cols).map_err(|err| {
|
||||
ArrowSnafu {
|
||||
raw: err,
|
||||
context:
|
||||
"Failed to create RecordBatch from values when eval datafusion scalar function",
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let res = self.fn_impl.evaluate(&rb).map_err(|err| {
|
||||
EvalDatafusionSnafu {
|
||||
raw: err,
|
||||
context: "Failed to evaluate datafusion scalar function",
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let res = common_query::columnar_value::ColumnarValue::try_from(&res)
|
||||
.map_err(BoxedError::new)
|
||||
.context(ExternalSnafu)?;
|
||||
let res_vec = res
|
||||
.try_into_vector(1)
|
||||
.map_err(BoxedError::new)
|
||||
.context(ExternalSnafu)?;
|
||||
let res_val = res_vec
|
||||
.try_get(0)
|
||||
.map_err(BoxedError::new)
|
||||
.context(ExternalSnafu)?;
|
||||
Ok(res_val)
|
||||
}
|
||||
}
|
||||
|
||||
// simply serialize the raw_fn instead of derive to avoid complex deserialize of struct
|
||||
impl Serialize for DfScalarFunction {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
self.raw_fn.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for DfScalarFunction {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
let raw_fn = RawDfScalarFn::deserialize(deserializer)?;
|
||||
let fn_impl = raw_fn.get_fn_impl().map_err(serde::de::Error::custom)?;
|
||||
DfScalarFunction::new(raw_fn, fn_impl).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct RawDfScalarFn {
|
||||
f: bytes::BytesMut,
|
||||
input_schema: RelationDesc,
|
||||
extensions: FunctionExtensions,
|
||||
}
|
||||
|
||||
impl RawDfScalarFn {
|
||||
pub fn from_proto(
|
||||
f: &substrait::substrait_proto_df::proto::expression::ScalarFunction,
|
||||
input_schema: RelationDesc,
|
||||
extensions: FunctionExtensions,
|
||||
) -> Result<Self, Error> {
|
||||
let mut buf = BytesMut::new();
|
||||
f.encode(&mut buf)
|
||||
.context(EncodeRelSnafu)
|
||||
.map_err(BoxedError::new)
|
||||
.context(crate::adapter::error::ExternalSnafu)?;
|
||||
Ok(Self {
|
||||
f: buf,
|
||||
input_schema,
|
||||
extensions,
|
||||
})
|
||||
}
|
||||
fn get_fn_impl(&self) -> Result<Arc<dyn PhysicalExpr>, Error> {
|
||||
let f = ScalarFunction::decode(&mut self.f.as_ref())
|
||||
.context(DecodeRelSnafu)
|
||||
.map_err(BoxedError::new)
|
||||
.context(crate::adapter::error::ExternalSnafu)?;
|
||||
|
||||
let input_schema = &self.input_schema;
|
||||
let extensions = &self.extensions;
|
||||
|
||||
from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::cmp::PartialEq for DfScalarFunction {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.raw_fn.eq(&other.raw_fn)
|
||||
}
|
||||
}
|
||||
|
||||
// can't derive Eq because of Arc<dyn PhysicalExpr> not eq, so implement it manually
|
||||
impl std::cmp::Eq for DfScalarFunction {}
|
||||
|
||||
impl std::cmp::PartialOrd for DfScalarFunction {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
impl std::cmp::Ord for DfScalarFunction {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.raw_fn.cmp(&other.raw_fn)
|
||||
}
|
||||
}
|
||||
impl std::hash::Hash for DfScalarFunction {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.raw_fn.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarExpr {
|
||||
pub fn with_type(self, typ: ColumnType) -> TypedExpr {
|
||||
TypedExpr::new(self, typ)
|
||||
@@ -179,6 +354,23 @@ impl ScalarExpr {
|
||||
Ok(ColumnType::new_nullable(func.signature().output))
|
||||
}
|
||||
ScalarExpr::If { then, .. } => then.typ(context),
|
||||
ScalarExpr::CallDf { df_scalar_fn } => {
|
||||
let arrow_typ = df_scalar_fn
|
||||
.fn_impl
|
||||
// TODO(discord9): get scheme from args instead?
|
||||
.data_type(df_scalar_fn.df_schema.as_arrow())
|
||||
.map_err(|err| {
|
||||
DatafusionSnafu {
|
||||
raw: err,
|
||||
context: "Failed to get data type from datafusion scalar function",
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let typ = ConcreteDataType::try_from(&arrow_typ)
|
||||
.map_err(BoxedError::new)
|
||||
.context(crate::adapter::error::ExternalSnafu)?;
|
||||
Ok(ColumnType::new_nullable(typ))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -253,6 +445,7 @@ impl ScalarExpr {
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
ScalarExpr::CallDf { df_scalar_fn } => df_scalar_fn.eval(values),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -421,6 +614,7 @@ impl ScalarExpr {
|
||||
f(then)?;
|
||||
f(els)
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -456,6 +650,7 @@ impl ScalarExpr {
|
||||
f(then)?;
|
||||
f(els)
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -497,7 +692,7 @@ impl ScalarExpr {
|
||||
return unsupported_err("Not a binary expression");
|
||||
};
|
||||
|
||||
// TODO: support simple transform like `now() + a < b` to `now() < b - a`
|
||||
// TODO(discord9): support simple transform like `now() + a < b` to `now() < b - a`
|
||||
|
||||
let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
|
||||
let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
|
||||
@@ -531,6 +726,15 @@ impl ScalarExpr {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use datatypes::arrow::array::Scalar;
|
||||
use query::parser::QueryLanguageParser;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
use substrait::extension_serializer;
|
||||
use substrait::substrait_proto_df::proto::expression::literal::LiteralType;
|
||||
use substrait::substrait_proto_df::proto::expression::Literal;
|
||||
use substrait::substrait_proto_df::proto::function_argument::ArgType;
|
||||
use substrait::substrait_proto_df::proto::r#type::Kind;
|
||||
use substrait::substrait_proto_df::proto::{r#type, FunctionArgument, Type};
|
||||
|
||||
use super::*;
|
||||
#[test]
|
||||
@@ -622,4 +826,37 @@ mod test {
|
||||
let res = expr.permute_map(&permute_map);
|
||||
assert!(matches!(res, Err(Error::InvalidQuery { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_df_scalar_function() {
|
||||
let raw_scalar_func = ScalarFunction {
|
||||
function_reference: 0,
|
||||
arguments: vec![FunctionArgument {
|
||||
arg_type: Some(ArgType::Value(Expression {
|
||||
rex_type: Some(RexType::Literal(Literal {
|
||||
nullable: false,
|
||||
type_variation_reference: 0,
|
||||
literal_type: Some(LiteralType::I64(-1)),
|
||||
})),
|
||||
})),
|
||||
}],
|
||||
output_type: None,
|
||||
..Default::default()
|
||||
};
|
||||
let input_schema = RelationDesc::try_new(
|
||||
RelationType::new(vec![ColumnType::new_nullable(
|
||||
ConcreteDataType::null_datatype(),
|
||||
)]),
|
||||
vec!["null_column".to_string()],
|
||||
)
|
||||
.unwrap();
|
||||
let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);
|
||||
let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
|
||||
let fn_impl = raw_fn.get_fn_impl().unwrap();
|
||||
let df_func = DfScalarFunction::new(raw_fn, fn_impl).unwrap();
|
||||
let as_str = serde_json::to_string(&df_func).unwrap();
|
||||
let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap();
|
||||
assert_eq!(df_func, from_str);
|
||||
assert_eq!(df_func.eval(&[Value::Null]).unwrap(), Value::Int64(1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,12 +14,14 @@
|
||||
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
|
||||
use datafusion_common::DFSchema;
|
||||
use datatypes::data_type::DataType;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::{ensure, OptionExt};
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use crate::adapter::error::{InvalidQuerySnafu, Result, UnexpectedSnafu};
|
||||
use crate::adapter::error::{DatafusionSnafu, InvalidQuerySnafu, Result, UnexpectedSnafu};
|
||||
use crate::expr::{MapFilterProject, SafeMfpPlan, ScalarExpr};
|
||||
|
||||
/// a set of column indices that are "keys" for the collection.
|
||||
@@ -338,9 +340,31 @@ pub struct RelationDesc {
|
||||
}
|
||||
|
||||
impl RelationDesc {
|
||||
pub fn to_df_schema(&self) -> Result<DFSchema> {
|
||||
let fields: Vec<_> = self
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (name, typ))| {
|
||||
let name = name.clone().unwrap_or(format!("Col_{i}"));
|
||||
let nullable = typ.nullable;
|
||||
let data_type = typ.scalar_type.clone().as_arrow_type();
|
||||
arrow_schema::Field::new(name, data_type, nullable)
|
||||
})
|
||||
.collect();
|
||||
let arrow_schema = arrow_schema::Schema::new(fields);
|
||||
|
||||
DFSchema::try_from(arrow_schema.clone()).map_err(|err| {
|
||||
DatafusionSnafu {
|
||||
raw: err,
|
||||
context: format!("Error when converting to DFSchema: {:?}", arrow_schema),
|
||||
}
|
||||
.build()
|
||||
})
|
||||
}
|
||||
|
||||
/// apply mfp, and also project col names for the projected columns
|
||||
pub fn apply_mfp(&self, mfp: &SafeMfpPlan) -> Result<Self> {
|
||||
// TODO: find a way to deduce name at best effect
|
||||
// TODO(discord9): find a way to deduce name at best effect
|
||||
let names = {
|
||||
let mfp = &mfp.mfp;
|
||||
let mut names = self.names.clone();
|
||||
|
||||
@@ -13,9 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
//! Transform Substrait into execution plan
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::buf::IntoIter;
|
||||
use common_error::ext::BoxedError;
|
||||
use common_telemetry::info;
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
@@ -25,6 +26,7 @@ use query::parser::QueryLanguageParser;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::DefaultSerializer;
|
||||
use query::QueryEngine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use session::context::QueryContext;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
|
||||
@@ -43,7 +45,6 @@ use crate::adapter::FlownodeContext;
|
||||
use crate::expr::GlobalId;
|
||||
use crate::plan::TypedPlan;
|
||||
use crate::repr::RelationType;
|
||||
|
||||
/// a simple macro to generate a not implemented error
|
||||
macro_rules! not_impl_err {
|
||||
($($arg:tt)*) => {
|
||||
@@ -67,17 +68,26 @@ mod expr;
|
||||
mod literal;
|
||||
mod plan;
|
||||
|
||||
pub(crate) use expr::from_scalar_fn_to_df_fn_impl;
|
||||
|
||||
/// In Substrait, a function can be defined by a u32 anchor, and the anchor can be mapped to a name
|
||||
///
|
||||
/// So in substrait plan, a ref to a function can be a single u32 anchor instead of a full name in string
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct FunctionExtensions {
|
||||
anchor_to_name: HashMap<u32, String>,
|
||||
anchor_to_name: BTreeMap<u32, String>,
|
||||
}
|
||||
|
||||
impl FunctionExtensions {
|
||||
pub fn from_iter(inner: impl IntoIterator<Item = (u32, impl ToString)>) -> Self {
|
||||
Self {
|
||||
anchor_to_name: inner.into_iter().map(|(k, s)| (k, s.to_string())).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration
|
||||
pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
|
||||
let mut anchor_to_name = HashMap::new();
|
||||
let mut anchor_to_name = BTreeMap::new();
|
||||
for e in extensions {
|
||||
match &e.mapping_type {
|
||||
Some(ext) => match ext {
|
||||
@@ -96,6 +106,10 @@ impl FunctionExtensions {
|
||||
pub fn get(&self, anchor: &u32) -> Option<&String> {
|
||||
self.anchor_to_name.get(anchor)
|
||||
}
|
||||
|
||||
pub fn inner_ref(&self) -> HashMap<u32, &String> {
|
||||
self.anchor_to_name.iter().map(|(k, v)| (*k, v)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// To reuse existing code for parse sql, the sql is first parsed into a datafusion logical plan,
|
||||
|
||||
@@ -53,14 +53,14 @@ use crate::expr::{
|
||||
TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
|
||||
};
|
||||
use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
|
||||
use crate::repr::{self, ColumnType, RelationType};
|
||||
use crate::repr::{self, ColumnType, RelationDesc, RelationType};
|
||||
use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
|
||||
|
||||
impl TypedExpr {
|
||||
fn from_substrait_agg_grouping(
|
||||
ctx: &mut FlownodeContext,
|
||||
groupings: &[Grouping],
|
||||
typ: &RelationType,
|
||||
typ: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<Vec<TypedExpr>, Error> {
|
||||
let _ = ctx;
|
||||
@@ -89,7 +89,7 @@ impl AggregateExpr {
|
||||
fn from_substrait_agg_measures(
|
||||
ctx: &mut FlownodeContext,
|
||||
measures: &[Measure],
|
||||
typ: &RelationType,
|
||||
typ: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<(Vec<AggregateExpr>, MapFilterProject), Error> {
|
||||
let _ = ctx;
|
||||
@@ -143,7 +143,7 @@ impl AggregateExpr {
|
||||
/// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)`
|
||||
pub fn from_substrait_agg_func(
|
||||
f: &proto::AggregateFunction,
|
||||
input_schema: &RelationType,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
filter: &Option<TypedExpr>,
|
||||
order_by: &Option<Vec<TypedExpr>>,
|
||||
@@ -320,7 +320,7 @@ impl TypedPlan {
|
||||
let group_exprs = TypedExpr::from_substrait_agg_grouping(
|
||||
ctx,
|
||||
&agg.groupings,
|
||||
&input.schema.typ,
|
||||
&input.schema,
|
||||
extensions,
|
||||
)?;
|
||||
|
||||
@@ -332,7 +332,7 @@ impl TypedPlan {
|
||||
let (mut aggr_exprs, post_mfp) = AggregateExpr::from_substrait_agg_measures(
|
||||
ctx,
|
||||
&agg.measures,
|
||||
&input.schema.typ,
|
||||
&input.schema,
|
||||
extensions,
|
||||
)?;
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
|
||||
#![warn(unused_imports)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_physical_expr::PhysicalExpr;
|
||||
use datatypes::data_type::ConcreteDataType as CDT;
|
||||
use itertools::Itertools;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
@@ -24,15 +27,17 @@ use substrait_proto::proto::function_argument::ArgType;
|
||||
use substrait_proto::proto::Expression;
|
||||
|
||||
use crate::adapter::error::{
|
||||
DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu, PlanSnafu,
|
||||
DatafusionSnafu, DatatypesSnafu, Error, EvalSnafu, InvalidQuerySnafu, NotImplementedSnafu,
|
||||
PlanSnafu,
|
||||
};
|
||||
use crate::expr::{
|
||||
BinaryFunc, ScalarExpr, TypedExpr, UnaryFunc, UnmaterializableFunc, VariadicFunc,
|
||||
BinaryFunc, DfScalarFunction, RawDfScalarFn, ScalarExpr, TypedExpr, UnaryFunc,
|
||||
UnmaterializableFunc, VariadicFunc,
|
||||
};
|
||||
use crate::repr::{ColumnType, RelationType};
|
||||
use crate::repr::{ColumnType, RelationDesc};
|
||||
use crate::transform::literal::{from_substrait_literal, from_substrait_type};
|
||||
use crate::transform::{substrait_proto, FunctionExtensions};
|
||||
// TODO: found proper place for this
|
||||
// TODO(discord9): found proper place for this
|
||||
/// ref to `arrow_schema::datatype` for type name
|
||||
fn typename_to_cdt(name: &str) -> CDT {
|
||||
match name {
|
||||
@@ -54,11 +59,64 @@ fn typename_to_cdt(name: &str) -> CDT {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`]
|
||||
pub(crate) fn from_scalar_fn_to_df_fn_impl(
|
||||
f: &ScalarFunction,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<Arc<dyn PhysicalExpr>, Error> {
|
||||
let e = Expression {
|
||||
rex_type: Some(RexType::ScalarFunction(f.clone())),
|
||||
};
|
||||
let schema = input_schema.to_df_schema()?;
|
||||
|
||||
let df_expr = futures::executor::block_on(async {
|
||||
// TODO(discord9): consider coloring everything async....
|
||||
substrait::df_logical_plan::consumer::from_substrait_rex(
|
||||
&datafusion::prelude::SessionContext::new(),
|
||||
&e,
|
||||
&schema,
|
||||
&extensions.inner_ref(),
|
||||
)
|
||||
.await
|
||||
});
|
||||
let expr = df_expr.map_err(|err| {
|
||||
DatafusionSnafu {
|
||||
raw: err,
|
||||
context: "Failed to convert substrait scalar function to datafusion scalar function",
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let phy_expr =
|
||||
datafusion::physical_expr::create_physical_expr(&expr, &schema, &Default::default())
|
||||
.map_err(|err| {
|
||||
DatafusionSnafu {
|
||||
raw: err,
|
||||
context: "Failed to create physical expression from logical expression",
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
Ok(phy_expr)
|
||||
}
|
||||
|
||||
impl TypedExpr {
|
||||
pub fn from_substrait_to_datafusion_scalar_func(
|
||||
f: &ScalarFunction,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
let phy_expr = from_scalar_fn_to_df_fn_impl(f, input_schema, extensions)?;
|
||||
let raw_fn = RawDfScalarFn::from_proto(f, input_schema.clone(), extensions.clone())?;
|
||||
let expr = DfScalarFunction::new(raw_fn, phy_expr)?;
|
||||
let expr = ScalarExpr::CallDf { df_scalar_fn: expr };
|
||||
// df already know it's own schema, so not providing here
|
||||
let ret_type = expr.typ(&[])?;
|
||||
Ok(TypedExpr::new(expr, ret_type))
|
||||
}
|
||||
/// Convert ScalarFunction into Flow's ScalarExpr
|
||||
pub fn from_substrait_scalar_func(
|
||||
f: &ScalarFunction,
|
||||
input_schema: &RelationType,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
let fn_name =
|
||||
@@ -182,7 +240,12 @@ impl TypedExpr {
|
||||
ret_type,
|
||||
))
|
||||
} else {
|
||||
not_impl_err!("Unsupported function {fn_name} with {arg_len} arguments")
|
||||
let try_as_df = Self::from_substrait_to_datafusion_scalar_func(
|
||||
f,
|
||||
input_schema,
|
||||
extensions,
|
||||
)?;
|
||||
Ok(try_as_df)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -191,7 +254,7 @@ impl TypedExpr {
|
||||
/// Convert IfThen into Flow's ScalarExpr
|
||||
pub fn from_substrait_ifthen_rex(
|
||||
if_then: &IfThen,
|
||||
input_schema: &RelationType,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
let ifs: Vec<_> = if_then
|
||||
@@ -246,7 +309,7 @@ impl TypedExpr {
|
||||
/// Convert Substrait Rex into Flow's ScalarExpr
|
||||
pub fn from_substrait_rex(
|
||||
e: &Expression,
|
||||
input_schema: &RelationType,
|
||||
input_schema: &RelationDesc,
|
||||
extensions: &FunctionExtensions,
|
||||
) -> Result<TypedExpr, Error> {
|
||||
match &e.rex_type {
|
||||
@@ -277,7 +340,7 @@ impl TypedExpr {
|
||||
}
|
||||
None => {
|
||||
let column = x.field as usize;
|
||||
let column_type = input_schema.column_types[column].clone();
|
||||
let column_type = input_schema.typ().column_types[column].clone();
|
||||
Ok(TypedExpr::new(ScalarExpr::Column(column), column_type))
|
||||
}
|
||||
},
|
||||
@@ -322,8 +385,6 @@ impl TypedExpr {
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common_time::{DateTime, Interval};
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::value::Value;
|
||||
@@ -584,10 +645,9 @@ mod test {
|
||||
output_type: None,
|
||||
..Default::default()
|
||||
};
|
||||
let input_schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]);
|
||||
let extensions = FunctionExtensions {
|
||||
anchor_to_name: HashMap::from([(0, "is_null".to_string())]),
|
||||
};
|
||||
let input_schema =
|
||||
RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
@@ -611,10 +671,9 @@ mod test {
|
||||
let input_schema = RelationType::new(vec![
|
||||
ColumnType::new(CDT::uint32_datatype(), false),
|
||||
ColumnType::new(CDT::uint32_datatype(), false),
|
||||
]);
|
||||
let extensions = FunctionExtensions {
|
||||
anchor_to_name: HashMap::from([(0, "add".to_string())]),
|
||||
};
|
||||
])
|
||||
.into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
@@ -639,10 +698,9 @@ mod test {
|
||||
let input_schema = RelationType::new(vec![
|
||||
ColumnType::new(CDT::timestamp_nanosecond_datatype(), false),
|
||||
ColumnType::new(CDT::string_datatype(), false),
|
||||
]);
|
||||
let extensions = FunctionExtensions {
|
||||
anchor_to_name: HashMap::from([(0, "tumble".to_string())]),
|
||||
};
|
||||
])
|
||||
.into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
@@ -668,10 +726,9 @@ mod test {
|
||||
let input_schema = RelationType::new(vec![
|
||||
ColumnType::new(CDT::timestamp_nanosecond_datatype(), false),
|
||||
ColumnType::new(CDT::string_datatype(), false),
|
||||
]);
|
||||
let extensions = FunctionExtensions {
|
||||
anchor_to_name: HashMap::from([(0, "tumble".to_string())]),
|
||||
};
|
||||
])
|
||||
.into_unnamed();
|
||||
let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]);
|
||||
let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
|
||||
@@ -64,7 +64,7 @@ impl TypedPlan {
|
||||
}
|
||||
|
||||
/// Convert Substrait Rel into Flow's TypedPlan
|
||||
/// TODO: SELECT DISTINCT(does it get compile with something else?)
|
||||
/// TODO(discord9): SELECT DISTINCT(does it get compile with something else?)
|
||||
pub fn from_substrait_rel(
|
||||
ctx: &mut FlownodeContext,
|
||||
rel: &Rel,
|
||||
@@ -80,7 +80,7 @@ impl TypedPlan {
|
||||
|
||||
let mut exprs: Vec<TypedExpr> = vec![];
|
||||
for e in &p.expressions {
|
||||
let expr = TypedExpr::from_substrait_rex(e, &input.schema.typ, extensions)?;
|
||||
let expr = TypedExpr::from_substrait_rex(e, &input.schema, extensions)?;
|
||||
exprs.push(expr);
|
||||
}
|
||||
let is_literal = exprs.iter().all(|expr| expr.expr.is_literal());
|
||||
@@ -133,7 +133,7 @@ impl TypedPlan {
|
||||
};
|
||||
|
||||
let expr = if let Some(condition) = filter.condition.as_ref() {
|
||||
TypedExpr::from_substrait_rex(condition, &input.schema.typ, extensions)?
|
||||
TypedExpr::from_substrait_rex(condition, &input.schema, extensions)?
|
||||
} else {
|
||||
return not_impl_err!("Filter without an condition is not valid");
|
||||
};
|
||||
@@ -213,7 +213,7 @@ fn rewrite_projection_after_reduce(
|
||||
reduce_output_type: &RelationDesc,
|
||||
proj_exprs: &mut Vec<TypedExpr>,
|
||||
) -> Result<(), Error> {
|
||||
// TODO: get keys correctly
|
||||
// TODO(discord9): get keys correctly
|
||||
let key_exprs = key_val_plan
|
||||
.key_plan
|
||||
.projection
|
||||
|
||||
@@ -177,7 +177,7 @@ pub struct Arrangement {
|
||||
///
|
||||
/// Since updates typically occur as a delete followed by an insert, a small vector of size 2 is used to store updates for efficiency.
|
||||
///
|
||||
/// TODO: Consider balancing the batch size?
|
||||
/// TODO(discord9): Consider balancing the batch size?
|
||||
spine: Spine,
|
||||
|
||||
/// Indicates whether the arrangement maintains a complete history of updates.
|
||||
|
||||
Reference in New Issue
Block a user