mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-31 04:10:38 +00:00
refactor: rewrite h3 functions to DataFusion style (#6942)
* refactor: rewrite h3 functions to DataFusion style
Signed-off-by: luofucong <luofc@foxmail.com>
* resolve PR comments
Signed-off-by: luofucong <luofc@foxmail.com>
---------
Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
@@ -15,11 +15,15 @@
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Result;
|
||||
use common_error::ext::{BoxedError, PlainError};
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_query::error::{ExecuteSnafu, Result};
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::Signature;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::vectors::VectorRef;
|
||||
use session::context::{QueryContextBuilder, QueryContextRef};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::state::FunctionState;
|
||||
|
||||
@@ -68,8 +72,26 @@ pub trait Function: fmt::Display + Sync + Send {
|
||||
/// The signature of function.
|
||||
fn signature(&self) -> Signature;
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
// TODO(LFC): Remove default implementation once all UDFs have implemented this function.
|
||||
let _ = args;
|
||||
Err(datafusion_common::DataFusionError::NotImplemented(
|
||||
"invoke_with_args".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Evaluate the function, e.g. run/execute the function.
|
||||
fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef>;
|
||||
/// TODO(LFC): Remove `eval` when all UDFs are rewritten to `invoke_with_args`
|
||||
fn eval(&self, _: &FunctionContext, _: &[VectorRef]) -> Result<VectorRef> {
|
||||
Err(BoxedError::new(PlainError::new(
|
||||
"unsupported".to_string(),
|
||||
StatusCode::Unsupported,
|
||||
)))
|
||||
.context(ExecuteSnafu)
|
||||
}
|
||||
|
||||
fn aliases(&self) -> &[String] {
|
||||
&[]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,9 +16,7 @@ use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{
|
||||
GeneralDataFusionSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result,
|
||||
};
|
||||
use common_query::error::{IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result};
|
||||
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
|
||||
use datafusion::common::{DFSchema, Result as DfResult};
|
||||
use datafusion::execution::SessionStateBuilder;
|
||||
@@ -106,21 +104,16 @@ impl MatchesFunction {
|
||||
let input_schema = Self::input_schema();
|
||||
let session_state = SessionStateBuilder::new().with_default_features().build();
|
||||
let planner = DefaultPhysicalPlanner::default();
|
||||
let physical_expr = planner
|
||||
.create_physical_expr(&like_expr, &input_schema, &session_state)
|
||||
.context(GeneralDataFusionSnafu)?;
|
||||
let physical_expr =
|
||||
planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
|
||||
|
||||
let data_array = data.to_arrow_array();
|
||||
let arrow_schema = Arc::new(input_schema.as_arrow().clone());
|
||||
let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
|
||||
|
||||
let num_rows = input_record_batch.num_rows();
|
||||
let result = physical_expr
|
||||
.evaluate(&input_record_batch)
|
||||
.context(GeneralDataFusionSnafu)?;
|
||||
let result_array = result
|
||||
.into_array(num_rows)
|
||||
.context(GeneralDataFusionSnafu)?;
|
||||
let result = physical_expr.evaluate(&input_record_batch)?;
|
||||
let result_array = result.into_array(num_rows)?;
|
||||
let result_vector =
|
||||
BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
|
||||
data_type: DataType::Boolean,
|
||||
@@ -210,14 +203,12 @@ impl PatternAst {
|
||||
/// Transform this AST with preset rules to make it correct.
|
||||
fn transform_ast(self) -> Result<Self> {
|
||||
self.transform_up(Self::collapse_binary_branch_fn)
|
||||
.context(GeneralDataFusionSnafu)
|
||||
.map(|data| data.data)?
|
||||
.transform_up(Self::eliminate_optional_fn)
|
||||
.context(GeneralDataFusionSnafu)
|
||||
.map(|data| data.data)?
|
||||
.transform_down(Self::eliminate_single_child_fn)
|
||||
.context(GeneralDataFusionSnafu)
|
||||
.map(|data| data.data)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Collapse binary branch with the same operator. I.e., this transformer
|
||||
|
||||
@@ -19,13 +19,12 @@ mod rate;
|
||||
use std::fmt;
|
||||
|
||||
pub use clamp::{ClampFunction, ClampMaxFunction, ClampMinFunction};
|
||||
use common_query::error::{GeneralDataFusionSnafu, Result};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::error::DataFusionError;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::vectors::VectorRef;
|
||||
pub use rate::RateFunction;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
@@ -68,7 +67,7 @@ impl Function for RangeFunction {
|
||||
.ok_or(DataFusionError::Internal(
|
||||
"No expr found in range_fn".into(),
|
||||
))
|
||||
.context(GeneralDataFusionSnafu)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// `range_fn` will never be used. As long as a legal signature is returned, the specific content of the signature does not matter.
|
||||
@@ -80,7 +79,7 @@ impl Function for RangeFunction {
|
||||
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
Err(DataFusionError::Internal(
|
||||
"range_fn just a empty function used in range select, It should not be eval!".into(),
|
||||
))
|
||||
.context(GeneralDataFusionSnafu)
|
||||
)
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,14 @@ impl ScalarUDFImpl for ScalarUdf {
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
|
||||
let result = self.function.invoke_with_args(args.clone());
|
||||
if !matches!(
|
||||
result,
|
||||
Err(datafusion_common::DataFusionError::NotImplemented(_))
|
||||
) {
|
||||
return result;
|
||||
}
|
||||
|
||||
let columns = args
|
||||
.args
|
||||
.iter()
|
||||
|
||||
@@ -16,11 +16,12 @@ use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::array::StringViewArray;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::vectors::{StringVector, VectorRef};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
|
||||
/// Generates build information
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@@ -38,17 +39,18 @@ impl Function for BuildFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Utf8)
|
||||
Ok(DataType::Utf8View)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::nullary(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
fn invoke_with_args(&self, _: ScalarFunctionArgs) -> datafusion_common::Result<ColumnarValue> {
|
||||
let build_info = common_version::build_info().to_string();
|
||||
let v = Arc::new(StringVector::from(vec![build_info]));
|
||||
Ok(v)
|
||||
Ok(ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
|
||||
build_info,
|
||||
]))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,16 +58,29 @@ impl Function for BuildFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::ArrayRef;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_build_function() {
|
||||
let build = BuildFunction;
|
||||
assert_eq!("build", build.name());
|
||||
assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap());
|
||||
assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap());
|
||||
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
|
||||
let build_info = common_version::build_info().to_string();
|
||||
let vector = build.eval(&FunctionContext::default(), &[]).unwrap();
|
||||
let expect: VectorRef = Arc::new(StringVector::from(vec![build_info]));
|
||||
assert_eq!(expect, vector);
|
||||
let actual = build
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![],
|
||||
arg_fields: vec![],
|
||||
number_rows: 0,
|
||||
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.unwrap();
|
||||
let actual = ColumnarValue::values_to_arrays(&[actual]).unwrap();
|
||||
let expect = vec![Arc::new(StringViewArray::from(vec![build_info])) as ArrayRef];
|
||||
assert_eq!(actual, expect);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user