refactor: DfUdfAdapter to bridge ScalarUdf (#3814)

* refactor: DfUdfAdapter to bridge ScalarUdf

Signed-off-by: tison <wander4096@gmail.com>

* tidy impl

Signed-off-by: tison <wander4096@gmail.com>

* for more

Signed-off-by: tison <wander4096@gmail.com>

* for more

Signed-off-by: tison <wander4096@gmail.com>

* for more

Signed-off-by: tison <wander4096@gmail.com>

---------

Signed-off-by: tison <wander4096@gmail.com>
This commit is contained in:
tison
2024-04-28 12:17:06 +08:00
committed by GitHub
parent ed8b13689e
commit e154dc5fd4
13 changed files with 173 additions and 191 deletions

View File

@@ -119,16 +119,12 @@ fn build_struct(
}
pub fn scalar_udf() -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
datafusion_expr::create_udf(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(Self::calc) as _),
Self::input_type(),
Arc::new(Self::return_type()),
Volatility::Immutable,
Arc::new(Self::calc) as _,
)
}

View File

@@ -29,6 +29,7 @@ pub use self::udf::ScalarUdf;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
use crate::logical_plan::accumulator::*;
use crate::signature::{Signature, Volatility};
/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible `ScalarFunction`:

View File

@@ -91,74 +91,67 @@ impl AggregateFunction {
}
}
/// Adapter holding the parts of an `AggregateFunction` that back the
/// `AggregateUDFImpl` implementation for DataFusion.
struct DfUdafAdapter {
/// Name returned by `name()`.
name: String,
/// DataFusion signature returned by `signature()`.
signature: datafusion_expr::Signature,
/// Callback invoked by `return_type()` with the argument types.
return_type_func: datafusion_expr::ReturnTypeFunction,
/// Factory invoked by `accumulator()` to build an accumulator.
accumulator: AccumulatorFactoryFunction,
/// Provides the state types from which `state_fields()` builds its fields.
creator: AggregateFunctionCreatorRef,
}
impl Debug for DfUdafAdapter {
    /// Renders the adapter as a `DfUdafAdapter` struct showing only `name` and
    /// `signature`; the callback fields are left out of the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DfUdafAdapter");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.finish()
    }
}
impl AggregateUDFImpl for DfUdafAdapter {
    /// Exposes the adapter as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the stored function name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the stored DataFusion signature.
    fn signature(&self) -> &datafusion_expr::Signature {
        &self.signature
    }

    /// Calls the stored return-type callback with `arg_types` and returns an
    /// owned copy of the type it yields.
    fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
        let resolved = (self.return_type_func)(arg_types)?;
        Ok(resolved.as_ref().clone())
    }

    /// Builds an accumulator by delegating to the stored factory.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        (self.accumulator)(acc_args)
    }

    /// Produces one nullable field per state type reported by the creator.
    ///
    /// Each field is named `{name}_{i}`, where `i` is the zero-based position of
    /// the state type. The value type and ordering fields are ignored.
    fn state_fields(&self, name: &str, _: ArrowDataType, _: Vec<Field>) -> Result<Vec<Field>> {
        let state_types = self.creator.state_types()?;
        let mut fields = Vec::with_capacity(state_types.len());
        for (i, state_type) in state_types.into_iter().enumerate() {
            let field_name = format!("{name}_{i}");
            fields.push(Field::new(field_name, state_type.as_arrow_type(), true));
        }
        Ok(fields)
    }
}
impl From<AggregateFunction> for DfAggregateUdf {
fn from(udaf: AggregateFunction) -> Self {
struct DfUdafAdapter {
name: String,
signature: datafusion_expr::Signature,
return_type_func: datafusion_expr::ReturnTypeFunction,
accumulator: AccumulatorFactoryFunction,
creator: AggregateFunctionCreatorRef,
}
impl Debug for DfUdafAdapter {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("DfUdafAdapter")
.field("name", &self.name)
.field("signature", &self.signature)
.finish()
}
}
impl AggregateUDFImpl for DfUdafAdapter {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &datafusion_expr::Signature {
&self.signature
}
fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
(self.return_type_func)(arg_types).map(|x| x.as_ref().clone())
}
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(acc_args)
}
fn state_fields(
&self,
name: &str,
_value_type: ArrowDataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
self.creator
.state_types()
.map(|x| {
(0..x.len())
.zip(x)
.map(|(i, t)| {
Field::new(format!("{}_{}", name, i), t.as_arrow_type(), true)
})
.collect::<Vec<_>>()
})
.map_err(|e| e.into())
}
}
DfUdafAdapter {
DfAggregateUdf::new_from_impl(DfUdafAdapter {
name: udaf.name,
signature: udaf.signature.into(),
return_type_func: to_df_return_type(udaf.return_type),
accumulator: to_df_accumulator_func(udaf.accumulator, udaf.creator.clone()),
creator: udaf.creator,
}
.into()
})
}
}

View File

@@ -14,6 +14,7 @@
//! Udf module contains foundational types that are used to represent UDFs.
//! It's modified from datafusion.
use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
@@ -21,7 +22,9 @@ use std::sync::Arc;
use datafusion_expr::{
ColumnarValue as DfColumnarValue,
ScalarFunctionImplementation as DfScalarFunctionImplementation, ScalarUDF as DfScalarUDF,
ScalarUDFImpl,
};
use datatypes::arrow::datatypes::DataType;
use crate::error::Result;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
@@ -68,25 +71,60 @@ impl ScalarUdf {
}
}
/// Adapter holding the parts of a `ScalarUdf` that back the `ScalarUDFImpl`
/// implementation for DataFusion.
#[derive(Clone)]
struct DfUdfAdapter {
/// Name returned by `name()`.
name: String,
/// DataFusion signature returned by `signature()`.
signature: datafusion_expr::Signature,
/// Callback invoked by `return_type()` with the argument types.
return_type: datafusion_expr::ReturnTypeFunction,
/// Function implementation invoked by `invoke()`.
fun: DfScalarFunctionImplementation,
}
impl Debug for DfUdfAdapter {
    /// Renders the adapter as a `DfUdfAdapter` struct showing only `name` and
    /// `signature`; the callback fields are left out of the output.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("DfUdfAdapter");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.finish()
    }
}
impl ScalarUDFImpl for DfUdfAdapter {
    /// Exposes the adapter as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the stored function name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the stored DataFusion signature.
    fn signature(&self) -> &datafusion_expr::Signature {
        &self.signature
    }

    /// Calls the stored return-type callback with `arg_types` and returns an
    /// owned copy of the type it yields.
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let resolved = (self.return_type)(arg_types)?;
        Ok(resolved.as_ref().clone())
    }

    /// Evaluates the function by delegating to the stored implementation.
    fn invoke(&self, args: &[DfColumnarValue]) -> datafusion_common::Result<DfColumnarValue> {
        (self.fun)(args)
    }
}
impl From<ScalarUdf> for DfScalarUDF {
fn from(udf: ScalarUdf) -> Self {
// TODO(LFC): remove deprecated
#[allow(deprecated)]
DfScalarUDF::new(
&udf.name,
&udf.signature.into(),
&to_df_return_type(udf.return_type),
&to_df_scalar_func(udf.fun),
)
DfScalarUDF::new_from_impl(DfUdfAdapter {
name: udf.name,
signature: udf.signature.into(),
return_type: to_df_return_type(udf.return_type),
fun: to_df_scalar_func(udf.fun),
})
}
}
fn to_df_scalar_func(fun: ScalarFunctionImplementation) -> DfScalarFunctionImplementation {
Arc::new(move |args: &[DfColumnarValue]| {
let args: Result<Vec<_>> = args.iter().map(TryFrom::try_from).collect();
let result = (fun)(&args?);
let result = fun(&args?);
result.map(From::from).map_err(|e| e.into())
})
}

View File

@@ -17,7 +17,7 @@ use std::sync::Arc;
use common_macro::range_fn;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::compute;

View File

@@ -20,7 +20,7 @@ use std::sync::Arc;
use common_macro::range_fn;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;

View File

@@ -20,7 +20,7 @@ use std::sync::Arc;
use common_macro::range_fn;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;

View File

@@ -35,8 +35,9 @@ use std::sync::Arc;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datafusion_expr::create_udf;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;
@@ -62,19 +63,23 @@ impl<const IS_COUNTER: bool, const IS_RATE: bool> ExtrapolatedRate<IS_COUNTER, I
Self { range_length }
}
fn input_type() -> Vec<DataType> {
vec![
fn scalar_udf_with_name(name: &str, range_length: i64) -> ScalarUDF {
let input_types = vec![
// timestamp range vector
RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
// value range vector
RangeArray::convert_data_type(DataType::Float64),
// timestamp vector
DataType::Timestamp(TimeUnit::Millisecond, None),
]
}
];
fn return_type() -> DataType {
DataType::Float64
create_udf(
name,
input_types,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _,
)
}
fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
@@ -204,17 +209,7 @@ impl ExtrapolatedRate<false, false> {
}
pub fn scalar_udf(range_length: i64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _),
)
Self::scalar_udf_with_name(Self::name(), range_length)
}
}
@@ -225,17 +220,7 @@ impl ExtrapolatedRate<true, true> {
}
pub fn scalar_udf(range_length: i64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _),
)
Self::scalar_udf_with_name(Self::name(), range_length)
}
}
@@ -246,17 +231,7 @@ impl ExtrapolatedRate<true, false> {
}
pub fn scalar_udf(range_length: i64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _),
)
Self::scalar_udf_with_name(Self::name(), range_length)
}
}

View File

@@ -20,8 +20,9 @@ use std::sync::Arc;
use datafusion::arrow::array::Float64Array;
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datafusion_expr::create_udf;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;
@@ -68,16 +69,12 @@ impl HoltWinters {
}
pub fn scalar_udf(level: f64, trend: f64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
create_udf(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(level, trend).calc(input)) as _),
Self::input_type(),
Arc::new(Self::return_type()),
Volatility::Immutable,
Arc::new(move |input: &_| Self::new(level, trend).calc(input)) as _,
)
}

View File

@@ -18,8 +18,9 @@ use std::sync::Arc;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datafusion_expr::create_udf;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;
@@ -42,16 +43,12 @@ impl<const IS_RATE: bool> IDelta<IS_RATE> {
}
pub fn scalar_udf() -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
create_udf(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(Self::calc) as _),
Self::input_type(),
Arc::new(Self::return_type()),
Volatility::Immutable,
Arc::new(Self::calc) as _,
)
}

View File

@@ -20,8 +20,9 @@ use std::sync::Arc;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datafusion_expr::create_udf;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;
@@ -44,32 +45,22 @@ impl PredictLinear {
}
pub fn scalar_udf(t: i64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
let input_types = vec![
// time index column
RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
// value column
RangeArray::convert_data_type(DataType::Float64),
];
create_udf(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(t).calc(input)) as _),
input_types,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(move |input: &_| Self::new(t).predict_linear(input)) as _,
)
}
// time index column and value column
fn input_type() -> Vec<DataType> {
vec![
RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
RangeArray::convert_data_type(DataType::Float64),
]
}
fn return_type() -> DataType {
DataType::Float64
}
fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
fn predict_linear(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
// construct matrix from input.
assert_eq!(input.len(), 2);
let ts_array = extract_array(&input[0])?;

View File

@@ -17,8 +17,9 @@ use std::sync::Arc;
use datafusion::arrow::array::Float64Array;
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datafusion_expr::create_udf;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;
@@ -40,32 +41,25 @@ impl QuantileOverTime {
}
pub fn scalar_udf(quantile: f64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
let input_types = vec![
// time index column
RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
// value column
RangeArray::convert_data_type(DataType::Float64),
];
create_udf(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(quantile).calc(input)) as _),
input_types,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(move |input: &_| Self::new(quantile).quantile_over_time(input)) as _,
)
}
// time index column and value column
fn input_type() -> Vec<DataType> {
vec![
RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
RangeArray::convert_data_type(DataType::Float64),
]
}
fn return_type() -> DataType {
DataType::Float64
}
fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
fn quantile_over_time(
&self,
input: &[ColumnarValue],
) -> Result<ColumnarValue, DataFusionError> {
// construct matrix from input.
assert_eq!(input.len(), 2);
let ts_array = extract_array(&input[0])?;

View File

@@ -20,7 +20,7 @@ use std::sync::Arc;
use common_macro::range_fn;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;