refactor: rewrite h3 functions to DataFusion style (#6942)

* refactor: rewrite h3 functions to DataFusion style

Signed-off-by: luofucong <luofc@foxmail.com>

* resolve PR comments

Signed-off-by: luofucong <luofc@foxmail.com>

---------

Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
LFC
2025-09-12 10:27:24 +08:00
committed by GitHub
parent 9fe7069146
commit 9ab87e11a4
12 changed files with 512 additions and 514 deletions

View File

@@ -15,11 +15,15 @@
use std::fmt;
use std::sync::Arc;
use common_query::error::Result;
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{ExecuteSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::Signature;
use datafusion::logical_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::vectors::VectorRef;
use session::context::{QueryContextBuilder, QueryContextRef};
use snafu::ResultExt;
use crate::state::FunctionState;
@@ -68,8 +72,26 @@ pub trait Function: fmt::Display + Sync + Send {
/// The signature of function.
fn signature(&self) -> Signature;
/// Evaluates the function through the DataFusion-style entry point.
///
/// The default body always fails with `DataFusionError::NotImplemented`, which marks a
/// function that has not been ported to this entry point yet. Callers can match on that
/// variant to detect such functions, so the variant must stay `NotImplemented`.
fn invoke_with_args(
    &self,
    args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
    // TODO(LFC): Remove default implementation once all UDFs have implemented this function.
    // The arguments are deliberately unused by the placeholder body.
    drop(args);
    let not_ported =
        datafusion_common::DataFusionError::NotImplemented(String::from("invoke_with_args"));
    Err(not_ported)
}
/// Evaluate the function, i.e. run/execute the function.
fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef>;
/// Legacy vector-based evaluation entry point.
///
/// The default body ignores both arguments and fails with an "unsupported" error that
/// carries `StatusCode::Unsupported`, wrapped through `ExecuteSnafu`.
///
/// TODO(LFC): Remove `eval` when all UDFs are rewritten to `invoke_with_args`
fn eval(&self, _: &FunctionContext, _: &[VectorRef]) -> Result<VectorRef> {
    let cause = BoxedError::new(PlainError::new(
        String::from("unsupported"),
        StatusCode::Unsupported,
    ));
    Err(cause).context(ExecuteSnafu)
}
fn aliases(&self) -> &[String] {
&[]

File diff suppressed because it is too large Load Diff

View File

@@ -16,9 +16,7 @@ use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use common_query::error::{
GeneralDataFusionSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result,
};
use common_query::error::{IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result};
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
use datafusion::common::{DFSchema, Result as DfResult};
use datafusion::execution::SessionStateBuilder;
@@ -106,21 +104,16 @@ impl MatchesFunction {
let input_schema = Self::input_schema();
let session_state = SessionStateBuilder::new().with_default_features().build();
let planner = DefaultPhysicalPlanner::default();
let physical_expr = planner
.create_physical_expr(&like_expr, &input_schema, &session_state)
.context(GeneralDataFusionSnafu)?;
let physical_expr =
planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
let data_array = data.to_arrow_array();
let arrow_schema = Arc::new(input_schema.as_arrow().clone());
let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
let num_rows = input_record_batch.num_rows();
let result = physical_expr
.evaluate(&input_record_batch)
.context(GeneralDataFusionSnafu)?;
let result_array = result
.into_array(num_rows)
.context(GeneralDataFusionSnafu)?;
let result = physical_expr.evaluate(&input_record_batch)?;
let result_array = result.into_array(num_rows)?;
let result_vector =
BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
data_type: DataType::Boolean,
@@ -210,14 +203,12 @@ impl PatternAst {
/// Transform this AST with preset rules to make it correct.
fn transform_ast(self) -> Result<Self> {
self.transform_up(Self::collapse_binary_branch_fn)
.context(GeneralDataFusionSnafu)
.map(|data| data.data)?
.transform_up(Self::eliminate_optional_fn)
.context(GeneralDataFusionSnafu)
.map(|data| data.data)?
.transform_down(Self::eliminate_single_child_fn)
.context(GeneralDataFusionSnafu)
.map(|data| data.data)
.map_err(Into::into)
}
/// Collapse binary branch with the same operator. I.e., this transformer

View File

@@ -19,13 +19,12 @@ mod rate;
use std::fmt;
pub use clamp::{ClampFunction, ClampMaxFunction, ClampMinFunction};
use common_query::error::{GeneralDataFusionSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::DataFusionError;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::VectorRef;
pub use rate::RateFunction;
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function_registry::FunctionRegistry;
@@ -68,7 +67,7 @@ impl Function for RangeFunction {
.ok_or(DataFusionError::Internal(
"No expr found in range_fn".into(),
))
.context(GeneralDataFusionSnafu)
.map_err(Into::into)
}
/// `range_fn` will never be used. As long as a legal signature is returned, the specific content of the signature does not matter.
@@ -80,7 +79,7 @@ impl Function for RangeFunction {
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
Err(DataFusionError::Internal(
"range_fn just a empty function used in range select, It should not be eval!".into(),
))
.context(GeneralDataFusionSnafu)
)
.into())
}
}

View File

@@ -65,6 +65,14 @@ impl ScalarUDFImpl for ScalarUdf {
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
let result = self.function.invoke_with_args(args.clone());
if !matches!(
result,
Err(datafusion_common::DataFusionError::NotImplemented(_))
) {
return result;
}
let columns = args
.args
.iter()

View File

@@ -16,11 +16,12 @@ use std::fmt;
use std::sync::Arc;
use common_query::error::Result;
use datafusion::arrow::array::StringViewArray;
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::{StringVector, VectorRef};
use datafusion::logical_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
use crate::function::{Function, FunctionContext};
use crate::function::Function;
/// Generates build information
#[derive(Clone, Debug, Default)]
@@ -38,17 +39,18 @@ impl Function for BuildFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
/// Declares the accepted call shape: no arguments, with `Volatility::Immutable`.
fn signature(&self) -> Signature {
    let volatility = Volatility::Immutable;
    Signature::nullary(volatility)
}
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn invoke_with_args(&self, _: ScalarFunctionArgs) -> datafusion_common::Result<ColumnarValue> {
let build_info = common_version::build_info().to_string();
let v = Arc::new(StringVector::from(vec![build_info]));
Ok(v)
Ok(ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
build_info,
]))))
}
}
@@ -56,16 +58,29 @@ impl Function for BuildFunction {
mod tests {
use std::sync::Arc;
use arrow_schema::Field;
use datafusion::arrow::array::ArrayRef;
use datafusion_common::config::ConfigOptions;
use super::*;
#[test]
fn test_build_function() {
let build = BuildFunction;
assert_eq!("build", build.name());
assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap());
assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap());
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
let build_info = common_version::build_info().to_string();
let vector = build.eval(&FunctionContext::default(), &[]).unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec![build_info]));
assert_eq!(expect, vector);
let actual = build
.invoke_with_args(ScalarFunctionArgs {
args: vec![],
arg_fields: vec![],
number_rows: 0,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(ConfigOptions::new()),
})
.unwrap();
let actual = ColumnarValue::values_to_arrays(&[actual]).unwrap();
let expect = vec![Arc::new(StringViewArray::from(vec![build_info])) as ArrayRef];
assert_eq!(actual, expect);
}
}