From 2712c5cd7ac58ff9850855778c45a3a2873de774 Mon Sep 17 00:00:00 2001 From: LFC <990479+MichaelScofield@users.noreply.github.com> Date: Tue, 23 Sep 2025 15:53:51 +0800 Subject: [PATCH] refactor: rewrite some UDFs to DataFusion style (part 3) (#6990) * refactor: rewrite some UDFs to DataFusion style (part 3) Signed-off-by: luofucong * resolve PR comments Signed-off-by: luofucong * resolve PR comments Signed-off-by: luofucong * resolve PR comments Signed-off-by: luofucong --------- Signed-off-by: luofucong --- .../function/src/aggrs/aggr_wrapper/tests.rs | 40 +++-- src/common/function/src/function.rs | 64 ++++++++ .../function/src/scalars/date/date_add.rs | 6 +- .../function/src/scalars/date/date_sub.rs | 6 +- .../src/scalars/expression/is_null.rs | 55 +++---- .../function/src/scalars/geo/geohash.rs | 10 +- src/common/function/src/scalars/geo/h3.rs | 58 +++---- .../function/src/scalars/geo/helpers.rs | 61 +------- .../function/src/scalars/geo/measure.rs | 87 ++++++----- .../function/src/scalars/geo/relation.rs | 146 ++++++------------ src/common/function/src/scalars/geo/s2.rs | 134 +++++++++------- src/common/function/src/scalars/geo/wkt.rs | 38 +++-- src/common/function/src/scalars/hll_count.rs | 108 ++++++++----- src/common/function/src/scalars/matches.rs | 7 +- .../function/src/scalars/uddsketch_calc.rs | 134 +++++++++------- .../src/scalars/vector/vector_subvector.rs | 7 +- src/common/function/src/system/database.rs | 95 ++++++++---- src/query/src/datafusion.rs | 11 ++ src/query/src/dist_plan/analyzer.rs | 42 ++++- 19 files changed, 622 insertions(+), 487 deletions(-) diff --git a/src/common/function/src/aggrs/aggr_wrapper/tests.rs b/src/common/function/src/aggrs/aggr_wrapper/tests.rs index 3f82a8fa9b..a262703b28 100644 --- a/src/common/function/src/aggrs/aggr_wrapper/tests.rs +++ b/src/common/function/src/aggrs/aggr_wrapper/tests.rs @@ -17,7 +17,6 @@ use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array}; use arrow::record_batch::RecordBatch; use arrow_schema::SchemaRef; use datafusion::catalog::{Session, TableProvider}; @@ -32,10 +31,14 @@ use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use datafusion::prelude::SessionContext; +use datafusion_common::arrow::array::{ArrayRef, AsArray, Float64Array, Int64Array, UInt64Array}; +use datafusion_common::arrow::datatypes::{Float64Type, UInt64Type}; use datafusion_common::{Column, TableReference}; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::sqlparser::ast::NullTreatment; -use datafusion_expr::{Aggregate, Expr, LogicalPlan, SortExpr, TableScan, lit}; +use datafusion_expr::{ + Aggregate, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, TableScan, lit, +}; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datatypes::arrow_array::StringArray; @@ -649,14 +652,20 @@ async fn test_udaf_correct_eval_result() { expected_output: None, expected_fn: Some(|arr| { let percent = ScalarValue::Float64(Some(0.5)).to_array().unwrap(); - let percent = datatypes::vectors::Helper::try_into_vector(percent).unwrap(); - let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap(); + let percent = ColumnarValue::Array(percent); + let state = ColumnarValue::Array(arr); let udd_calc = UddSketchCalcFunction; let res = udd_calc - .eval(&Default::default(), &[percent, state]) + .invoke_with_args(ScalarFunctionArgs { + args: vec![percent, state], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::Float64, false)), + config_options: Arc::new(Default::default()), + }) .unwrap(); - let binding = res.to_arrow_array(); - let res_arr = binding.as_any().downcast_ref::().unwrap(); + let binding = res.to_array(1).unwrap(); + let res_arr = binding.as_primitive::(); assert!(res_arr.len() == 1); assert!((res_arr.value(0) - 2.856578984907706f64).abs() <= f64::EPSILON); true @@ -683,11 +692,20 @@ async fn test_udaf_correct_eval_result() { ]))], expected_output: None, expected_fn: Some(|arr| { - let state = datatypes::vectors::Helper::try_into_vector(arr).unwrap(); + let number_rows = arr.len(); + let state = ColumnarValue::Array(arr); let hll_calc = HllCalcFunction; - let res = hll_calc.eval(&Default::default(), &[state]).unwrap(); - let binding = res.to_arrow_array(); - let res_arr = binding.as_any().downcast_ref::().unwrap(); + let res = hll_calc + .invoke_with_args(ScalarFunctionArgs { + args: vec![state], + arg_fields: vec![], + number_rows, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options: Arc::new(Default::default()), + }) + .unwrap(); + let binding = res.to_array(1).unwrap(); + let res_arr = binding.as_primitive::(); assert!(res_arr.len() == 1); assert_eq!(res_arr.value(0), 3); true diff --git a/src/common/function/src/function.rs b/src/common/function/src/function.rs index 18ad60b622..4ee271192d 100644 --- a/src/common/function/src/function.rs +++ b/src/common/function/src/function.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::Any; use std::fmt; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; use common_error::ext::{BoxedError, PlainError}; @@ -20,6 +22,9 @@ use common_error::status_code::StatusCode; use common_query::error::{ExecuteSnafu, Result}; use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::config::{ConfigEntry, ConfigExtension, ExtensionOptions}; use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::vectors::VectorRef; use session::context::{QueryContextBuilder, QueryContextRef}; @@ -60,6 +65,42 @@ impl Default for FunctionContext { } } +impl ExtensionOptions for FunctionContext { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn cloned(&self) -> Box { + Box::new(self.clone()) + } + + fn set(&mut self, _: &str, _: &str) -> datafusion_common::Result<()> { + Err(DataFusionError::NotImplemented( + "set options for `FunctionContext`".to_string(), + )) + } + + fn entries(&self) -> Vec { + vec![] + } +} + +impl Debug for FunctionContext { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("FunctionContext") + .field("query_ctx", &self.query_ctx) + .finish() + } +} + +impl ConfigExtension for FunctionContext { + const PREFIX: &'static str = "FunctionContext"; +} + /// Scalar function trait, modified from databend to adapt datafusion /// TODO(dennis): optimize function by it's features such as monotonicity etc. pub trait Function: fmt::Display + Sync + Send { @@ -99,3 +140,26 @@ pub trait Function: fmt::Display + Sync + Send { } pub type FunctionRef = Arc; + +/// Find the [FunctionContext] in the [ScalarFunctionArgs]. The [FunctionContext] was set +/// previously in the DataFusion session context creation, and is passed all the way down to the +/// args by DataFusion. +pub(crate) fn find_function_context( + args: &ScalarFunctionArgs, +) -> datafusion_common::Result<&FunctionContext> { + let Some(x) = args.config_options.extensions.get::() else { + return Err(DataFusionError::Execution( + "function context is not set".to_string(), + )); + }; + Ok(x) +} + +/// Extract UDF arguments (as Arrow's [ArrayRef]) from [ScalarFunctionArgs] directly. +pub(crate) fn extract_args( + name: &str, + args: &ScalarFunctionArgs, +) -> datafusion_common::Result<[ArrayRef; N]> { + ColumnarValue::values_to_arrays(&args.args) + .and_then(|x| datafusion_common::utils::take_function_args(name, x)) +} diff --git a/src/common/function/src/scalars/date/date_add.rs b/src/common/function/src/scalars/date/date_add.rs index 973535fc7b..ced1e02cfb 100644 --- a/src/common/function/src/scalars/date/date_add.rs +++ b/src/common/function/src/scalars/date/date_add.rs @@ -16,13 +16,12 @@ use std::fmt; use common_query::error::{ArrowComputeSnafu, Result}; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::utils; use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::compute::kernels::numeric; use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; use snafu::ResultExt; -use crate::function::Function; +use crate::function::{Function, extract_args}; use crate::helper; /// A function adds an interval value to Timestamp, Date, and return the result. @@ -63,8 +62,7 @@ impl Function for DateAddFunction { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [left, right] = utils::take_function_args(self.name(), args)?; + let [left, right] = extract_args(self.name(), &args)?; let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?; Ok(ColumnarValue::Array(result)) diff --git a/src/common/function/src/scalars/date/date_sub.rs b/src/common/function/src/scalars/date/date_sub.rs index 6ed5b84c90..71ef5701ef 100644 --- a/src/common/function/src/scalars/date/date_sub.rs +++ b/src/common/function/src/scalars/date/date_sub.rs @@ -16,13 +16,12 @@ use std::fmt; use common_query::error::{ArrowComputeSnafu, Result}; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::utils; use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::compute::kernels::numeric; use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; use snafu::ResultExt; -use crate::function::Function; +use crate::function::{Function, extract_args}; use crate::helper; /// A function subtracts an interval value to Timestamp, Date, and return the result. @@ -63,8 +62,7 @@ impl Function for DateSubFunction { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [left, right] = utils::take_function_args(self.name(), args)?; + let [left, right] = extract_args(self.name(), &args)?; let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?; Ok(ColumnarValue::Array(result)) diff --git a/src/common/function/src/scalars/expression/is_null.rs b/src/common/function/src/scalars/expression/is_null.rs index fd28f8682a..f536b8a21e 100644 --- a/src/common/function/src/scalars/expression/is_null.rs +++ b/src/common/function/src/scalars/expression/is_null.rs @@ -16,17 +16,12 @@ use std::fmt; use std::fmt::Display; use std::sync::Arc; -use common_query::error; -use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result}; -use datafusion::arrow::array::ArrayRef; +use common_query::error::Result; use datafusion::arrow::compute::is_null; use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::prelude::VectorRef; -use datatypes::vectors::Helper; -use snafu::{ResultExt, ensure}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; const NAME: &str = "isnull"; @@ -53,25 +48,14 @@ impl Function for IsNullFunction { Signature::any(1, Volatility::Immutable) } - fn eval( + fn invoke_with_args( &self, - _func_ctx: &FunctionContext, - columns: &[VectorRef], - ) -> common_query::error::Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly one, have: {}", - columns.len() - ), - } - ); - let values = &columns[0]; - let arrow_array = &values.to_arrow_array(); - let result = is_null(arrow_array).context(ArrowComputeSnafu)?; + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0] = extract_args(self.name(), &args)?; + let result = is_null(&arg0)?; - Helper::try_into_vector(Arc::new(result) as ArrayRef).context(error::FromArrowArraySnafu) + Ok(ColumnarValue::Array(Arc::new(result))) } } @@ -79,9 +63,9 @@ impl Function for IsNullFunction { mod tests { use std::sync::Arc; + use arrow_schema::Field; + use datafusion_common::arrow::array::{AsArray, BooleanArray, Float32Array}; use datafusion_expr::TypeSignature; - use datatypes::scalars::ScalarVector; - use datatypes::vectors::{BooleanVector, Float32Vector}; use super::*; #[test] @@ -98,9 +82,20 @@ mod tests { ); let values = vec![None, Some(3.0), None]; - let args: Vec = vec![Arc::new(Float32Vector::from(values))]; - let vector = is_null.eval(&FunctionContext::default(), &args).unwrap(); - let expect: VectorRef = Arc::new(BooleanVector::from_vec(vec![true, false, true])); + let result = is_null + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(Float32Array::from(values)))], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("", DataType::Boolean, false)), + config_options: Arc::new(Default::default()), + }) + .unwrap(); + let ColumnarValue::Array(result) = result else { + unreachable!() + }; + let vector = result.as_boolean(); + let expect = &BooleanArray::from(vec![true, false, true]); assert_eq!(expect, vector); } } diff --git a/src/common/function/src/scalars/geo/geohash.rs b/src/common/function/src/scalars/geo/geohash.rs index 43085b6b7e..f939208ac7 100644 --- a/src/common/function/src/scalars/geo/geohash.rs +++ b/src/common/function/src/scalars/geo/geohash.rs @@ -21,13 +21,13 @@ use common_query::error::{self, Result}; use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder}; use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type}; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{DataFusionError, utils}; +use datafusion_common::DataFusionError; use datafusion_expr::type_coercion::aggregates::INTEGERS; use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use geohash::Coord; use snafu::ResultExt; -use crate::function::Function; +use crate::function::{Function, extract_args}; use crate::scalars::geo::helpers; fn ensure_resolution_usize(v: u8) -> datafusion_common::Result { @@ -77,8 +77,7 @@ impl Function for GeohashFunction { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?; + let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?; let lat_vec = helpers::cast::(&lat_vec)?; let lat_vec = lat_vec.as_primitive::(); @@ -169,8 +168,7 @@ impl Function for GeohashNeighboursFunction { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?; + let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?; let lat_vec = helpers::cast::(&lat_vec)?; let lat_vec = lat_vec.as_primitive::(); diff --git a/src/common/function/src/scalars/geo/h3.rs b/src/common/function/src/scalars/geo/h3.rs index 0014e3a3cb..d32c30c81a 100644 --- a/src/common/function/src/scalars/geo/h3.rs +++ b/src/common/function/src/scalars/geo/h3.rs @@ -25,7 +25,7 @@ use datafusion::arrow::array::{ use datafusion::arrow::compute; use datafusion::arrow::datatypes::{Float64Type, Int64Type, UInt8Type, UInt64Type}; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{DataFusionError, ScalarValue, utils}; +use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::type_coercion::aggregates::INTEGERS; use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use datatypes::arrow::datatypes::{DataType, Field}; @@ -33,7 +33,7 @@ use derive_more::Display; use h3o::{CellIndex, LatLng, Resolution}; use snafu::prelude::*; -use crate::function::Function; +use crate::function::{Function, extract_args}; use crate::scalars::geo::helpers; static CELL_TYPES: LazyLock> = @@ -85,8 +85,7 @@ impl Function for H3LatLngToCell { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?; + let [lat_vec, lon_vec, resolution_vec] = extract_args(self.name(), &args)?; let lat_vec = helpers::cast::(&lat_vec)?; let lat_vec = lat_vec.as_primitive::(); @@ -167,8 +166,7 @@ impl Function for H3LatLngToCellString { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?; + let [lat_vec, lon_vec, resolution_vec] = extract_args(self.name(), &args)?; let lat_vec = helpers::cast::(&lat_vec)?; let lat_vec = lat_vec.as_primitive::(); @@ -233,8 +231,7 @@ impl Function for H3CellToString { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec] = extract_args(self.name(), &args)?; let size = cell_vec.len(); let mut builder = StringViewBuilder::with_capacity(size); @@ -272,8 +269,7 @@ impl Function for H3StringToCell { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [string_vec] = utils::take_function_args(self.name(), args)?; + let [string_vec] = extract_args(self.name(), &args)?; let string_vec = compute::cast(string_vec.as_ref(), &DataType::Utf8View)?; let string_vec = datafusion_common::downcast_value!(string_vec, StringViewArray); @@ -323,8 +319,7 @@ impl Function for H3CellCenterLatLng { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec] = extract_args(self.name(), &args)?; let size = cell_vec.len(); let mut builder = ListBuilder::new(Float64Builder::new()); @@ -368,8 +363,7 @@ impl Function for H3CellResolution { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec] = extract_args(self.name(), &args)?; let size = cell_vec.len(); let mut builder = UInt8Builder::with_capacity(cell_vec.len()); @@ -406,8 +400,7 @@ impl Function for H3CellBase { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec] = extract_args(self.name(), &args)?; let size = cell_vec.len(); let mut builder = UInt8Builder::with_capacity(size); @@ -445,8 +438,7 @@ impl Function for H3CellIsPentagon { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec] = extract_args(self.name(), &args)?; let size = cell_vec.len(); let mut builder = BooleanBuilder::with_capacity(size); @@ -544,8 +536,7 @@ impl Function for H3CellToChildren { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec, res_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec, res_vec] = extract_args(self.name(), &args)?; let resolutions = helpers::cast::(&res_vec)?; let resolutions = resolutions.as_primitive::(); @@ -638,8 +629,7 @@ fn calculate_cell_child_property( where F: Fn(CellIndex, Resolution) -> Option, { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cells, resolutions] = utils::take_function_args(name, args)?; + let [cells, resolutions] = extract_args(name, &args)?; let resolutions = helpers::cast::(&resolutions)?; let resolutions = resolutions.as_primitive::(); @@ -695,8 +685,7 @@ impl Function for H3ChildPosToCell { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [pos_vec, cell_vec, res_vec] = utils::take_function_args(self.name(), args)?; + let [pos_vec, cell_vec, res_vec] = extract_args(self.name(), &args)?; let resolutions = helpers::cast::(&res_vec)?; let resolutions = resolutions.as_primitive::(); @@ -747,8 +736,7 @@ impl Function for H3GridDisk { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec, k_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec, k_vec] = extract_args(self.name(), &args)?; let size = cell_vec.len(); let mut builder = ListBuilder::new(UInt64Builder::new()); @@ -797,8 +785,7 @@ impl Function for H3GridDiskDistances { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_vec, k_vec] = utils::take_function_args(self.name(), args)?; + let [cell_vec, k_vec] = extract_args(self.name(), &args)?; let size = cell_vec.len(); let mut builder = ListBuilder::new(UInt64Builder::new()); @@ -842,8 +829,7 @@ impl Function for H3GridDistance { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?; + let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?; let size = cell_this_vec.len(); let mut builder = Int32Builder::with_capacity(size); @@ -902,8 +888,7 @@ impl Function for H3GridPathCells { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?; + let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?; let size = cell_this_vec.len(); let mut builder = ListBuilder::new(UInt64Builder::new()); @@ -972,8 +957,7 @@ impl Function for H3CellContains { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cells_vec, cell_this_vec] = utils::take_function_args(self.name(), args)?; + let [cells_vec, cell_this_vec] = extract_args(self.name(), &args)?; let size = cell_this_vec.len(); let mut builder = BooleanBuilder::with_capacity(size); @@ -1027,8 +1011,7 @@ impl Function for H3CellDistanceSphereKm { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?; + let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?; let size = cell_this_vec.len(); let mut builder = Float64Builder::with_capacity(size); @@ -1084,8 +1067,7 @@ impl Function for H3CellDistanceEuclideanDegree { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [cell_this_vec, cell_that_vec] = utils::take_function_args(self.name(), args)?; + let [cell_this_vec, cell_that_vec] = extract_args(self.name(), &args)?; let size = cell_this_vec.len(); let mut builder = Float64Builder::with_capacity(size); diff --git a/src/common/function/src/scalars/geo/helpers.rs b/src/common/function/src/scalars/geo/helpers.rs index c76e188990..5bdef16ef8 100644 --- a/src/common/function/src/scalars/geo/helpers.rs +++ b/src/common/function/src/scalars/geo/helpers.rs @@ -15,63 +15,14 @@ use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType}; use datafusion::arrow::compute; -macro_rules! ensure_columns_len { - ($columns:ident) => { - snafu::ensure!( - $columns.windows(2).all(|c| c[0].len() == c[1].len()), - common_query::error::InvalidFuncArgsSnafu { - err_msg: "The length of input columns are in different size" - } - ) - }; - ($column_a:ident, $column_b:ident, $($column_n:ident),*) => { - snafu::ensure!( - { - let mut result = $column_a.len() == $column_b.len(); - $( - result = result && ($column_a.len() == $column_n.len()); - )* - result - } - common_query::error::InvalidFuncArgsSnafu { - err_msg: "The length of input columns are in different size" - } - ) - }; -} - -pub(crate) use ensure_columns_len; - -macro_rules! ensure_columns_n { - ($columns:ident, $n:literal) => { - snafu::ensure!( - $columns.len() == $n, - common_query::error::InvalidFuncArgsSnafu { - err_msg: format!( - "The length of arguments is not correct, expect {}, provided : {}", - stringify!($n), - $columns.len() - ), - } - ); - - if $n > 1 { - ensure_columns_len!($columns); - } - }; -} - -pub(crate) use ensure_columns_n; - macro_rules! ensure_and_coerce { ($compare:expr, $coerce:expr) => {{ - snafu::ensure!( - $compare, - common_query::error::InvalidFuncArgsSnafu { - err_msg: "Argument was outside of acceptable range " - } - ); - Ok($coerce) + if !$compare { + return Err(datafusion_common::DataFusionError::Execution( + "argument out of valid range".to_string(), + )); + } + Ok(Some($coerce)) }}; } diff --git a/src/common/function/src/scalars/geo/measure.rs b/src/common/function/src/scalars/geo/measure.rs index cab77fae6e..047fa591d1 100644 --- a/src/common/function/src/scalars/geo/measure.rs +++ b/src/common/function/src/scalars/geo/measure.rs @@ -12,21 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use common_error::ext::{BoxedError, PlainError}; use common_error::status_code::StatusCode; use common_query::error::{self, Result}; -use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{Float64VectorBuilder, MutableVector, VectorRef}; +use datafusion_common::arrow::array::{Array, AsArray, Float64Builder}; +use datafusion_common::arrow::compute; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use derive_more::Display; use geo::algorithm::line_measures::metric_spaces::Euclidean; use geo::{Area, Distance, Haversine}; use geo_types::Geometry; use snafu::ResultExt; -use crate::function::{Function, FunctionContext}; -use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n}; +use crate::function::{Function, extract_args}; use crate::scalars::geo::wkt::parse_wkt; /// Return WGS84(SRID: 4326) euclidean distance between two geometry object, in degree @@ -47,33 +48,38 @@ impl Function for STDistance { Signature::string(2, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; - let wkt_this_vec = &columns[0]; - let wkt_that_vec = &columns[1]; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let wkt_this_vec = arg0.as_string_view(); + let arg1 = compute::cast(&arg1, &DataType::Utf8View)?; + let wkt_that_vec = arg1.as_string_view(); let size = wkt_this_vec.len(); - let mut results = Float64VectorBuilder::with_capacity(size); + let mut builder = Float64Builder::with_capacity(size); for i in 0..size { - let wkt_this = wkt_this_vec.get(i).as_string(); - let wkt_that = wkt_that_vec.get(i).as_string(); + let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i)); + let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i)); let result = match (wkt_this, wkt_that) { (Some(wkt_this), Some(wkt_that)) => { - let geom_this = parse_wkt(&wkt_this)?; - let geom_that = parse_wkt(&wkt_that)?; + let geom_this = parse_wkt(wkt_this)?; + let geom_that = parse_wkt(wkt_that)?; Some(Euclidean::distance(&geom_this, &geom_that)) } _ => None, }; - results.push(result); + builder.append_option(result); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -95,23 +101,28 @@ impl Function for STDistanceSphere { Signature::string(2, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; - let wkt_this_vec = &columns[0]; - let wkt_that_vec = &columns[1]; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let wkt_this_vec = arg0.as_string_view(); + let arg1 = compute::cast(&arg1, &DataType::Utf8View)?; + let wkt_that_vec = arg1.as_string_view(); let size = wkt_this_vec.len(); - let mut results = Float64VectorBuilder::with_capacity(size); + let mut builder = Float64Builder::with_capacity(size); for i in 0..size { - let wkt_this = wkt_this_vec.get(i).as_string(); - let wkt_that = wkt_that_vec.get(i).as_string(); + let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i)); + let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i)); let result = match (wkt_this, wkt_that) { (Some(wkt_this), Some(wkt_that)) => { - let geom_this = parse_wkt(&wkt_this)?; - let geom_that = parse_wkt(&wkt_that)?; + let geom_this = parse_wkt(wkt_this)?; + let geom_that = parse_wkt(wkt_that)?; match (geom_this, geom_that) { (Geometry::Point(this), Geometry::Point(that)) => { @@ -128,10 +139,10 @@ impl Function for STDistanceSphere { _ => None, }; - results.push(result); + builder.append_option(result); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -153,27 +164,31 @@ impl Function for STArea { Signature::string(1, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 1); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0] = extract_args(self.name(), &args)?; - let wkt_vec = &columns[0]; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let wkt_vec = arg0.as_string_view(); let size = wkt_vec.len(); - let mut results = Float64VectorBuilder::with_capacity(size); + let mut builder = Float64Builder::with_capacity(size); for i in 0..size { - let wkt = wkt_vec.get(i).as_string(); + let wkt = wkt_vec.is_valid(i).then(|| wkt_vec.value(i)); let result = if let Some(wkt) = wkt { - let geom = parse_wkt(&wkt)?; + let geom = parse_wkt(wkt)?; Some(geom.unsigned_area()) } else { None }; - results.push(result); + builder.append_option(result); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } diff --git a/src/common/function/src/scalars/geo/relation.rs b/src/common/function/src/scalars/geo/relation.rs index 6dc116c4b3..f0843700c2 100644 --- a/src/common/function/src/scalars/geo/relation.rs +++ b/src/common/function/src/scalars/geo/relation.rs @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use common_query::error::Result; -use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef}; +use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder}; +use datafusion_common::arrow::compute; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use derive_more::Display; use geo::algorithm::contains::Contains; use geo::algorithm::intersects::Intersects; use geo::algorithm::within::Within; +use geo_types::Geometry; -use crate::function::{Function, FunctionContext}; -use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n}; +use crate::function::{Function, extract_args}; use crate::scalars::geo::wkt::parse_wkt; /// Test if spatial relationship: contains @@ -31,46 +33,11 @@ use crate::scalars::geo::wkt::parse_wkt; #[display("{}", self.name())] pub struct STContains; -impl Function for STContains { - fn name(&self) -> &str { - "st_contains" - } +impl StFunction for STContains { + const NAME: &'static str = "st_contains"; - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Boolean) - } - - fn signature(&self) -> Signature { - Signature::string(2, Volatility::Stable) - } - - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); - - let wkt_this_vec = &columns[0]; - let wkt_that_vec = &columns[1]; - - let size = wkt_this_vec.len(); - let mut results = BooleanVectorBuilder::with_capacity(size); - - for i in 0..size { - let wkt_this = wkt_this_vec.get(i).as_string(); - let wkt_that = wkt_that_vec.get(i).as_string(); - - let result = match (wkt_this, wkt_that) { - (Some(wkt_this), Some(wkt_that)) => { - let geom_this = parse_wkt(&wkt_this)?; - let geom_that = parse_wkt(&wkt_that)?; - - Some(geom_this.contains(&geom_that)) - } - _ => None, - }; - - results.push(result); - } - - Ok(results.to_vector()) + fn invoke(g1: Geometry, g2: Geometry) -> bool { + g1.contains(&g2) } } @@ -79,46 +46,11 @@ impl Function for STContains { #[display("{}", self.name())] pub struct STWithin; -impl Function for STWithin { - fn name(&self) -> &str { - "st_within" - } +impl StFunction for STWithin { + const NAME: &'static str = "st_within"; - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Boolean) - } - - fn signature(&self) -> Signature { - Signature::string(2, Volatility::Stable) - } - - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); - - let wkt_this_vec = &columns[0]; - let wkt_that_vec = &columns[1]; - - let size = wkt_this_vec.len(); - let mut results = BooleanVectorBuilder::with_capacity(size); - - for i in 0..size { - let wkt_this = wkt_this_vec.get(i).as_string(); - let wkt_that = wkt_that_vec.get(i).as_string(); - - let result = match (wkt_this, wkt_that) { - (Some(wkt_this), Some(wkt_that)) => { - let geom_this = parse_wkt(&wkt_this)?; - let geom_that = parse_wkt(&wkt_that)?; - - Some(geom_this.is_within(&geom_that)) - } - _ => None, - }; - - results.push(result); - } - - Ok(results.to_vector()) + fn invoke(g1: Geometry, g2: Geometry) -> bool { + g1.is_within(&g2) } } @@ -127,9 +59,23 @@ impl Function for STWithin { #[display("{}", self.name())] pub struct STIntersects; -impl Function for STIntersects { +impl StFunction for STIntersects { + const NAME: &'static str = "st_intersects"; + + fn invoke(g1: Geometry, g2: Geometry) -> bool { + g1.intersects(&g2) + } +} + +trait StFunction { + const NAME: &'static str; + + fn invoke(g1: Geometry, g2: Geometry) -> bool; +} + +impl Function for T { fn name(&self) -> &str { - "st_intersects" + T::NAME } fn return_type(&self, _: &[DataType]) -> Result { @@ -140,32 +86,34 @@ impl Function for STIntersects { Signature::string(2, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; - let wkt_this_vec = &columns[0]; - let wkt_that_vec = &columns[1]; + let arg0 = compute::cast(&arg0, &DataType::Utf8View)?; + let wkt_this_vec = arg0.as_string_view(); + let arg1 = compute::cast(&arg1, &DataType::Utf8View)?; + let wkt_that_vec = arg1.as_string_view(); let size = wkt_this_vec.len(); - let mut results = BooleanVectorBuilder::with_capacity(size); + let mut builder = BooleanBuilder::with_capacity(size); for i in 0..size { - let wkt_this = wkt_this_vec.get(i).as_string(); - let wkt_that = wkt_that_vec.get(i).as_string(); + let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i)); + let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i)); let result = match (wkt_this, wkt_that) { (Some(wkt_this), Some(wkt_that)) => { - let geom_this = parse_wkt(&wkt_this)?; - let geom_that = parse_wkt(&wkt_that)?; - - Some(geom_this.intersects(&geom_that)) + Some(T::invoke(parse_wkt(wkt_this)?, parse_wkt(wkt_that)?)) } _ => None, }; - results.push(result); + builder.append_option(result); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } diff --git a/src/common/function/src/scalars/geo/s2.rs b/src/common/function/src/scalars/geo/s2.rs index 4e356520c2..1b3126d33c 100644 --- a/src/common/function/src/scalars/geo/s2.rs +++ b/src/common/function/src/scalars/geo/s2.rs @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::value::Value; -use datatypes::vectors::{MutableVector, StringVectorBuilder, UInt64VectorBuilder, VectorRef}; +use datafusion_common::ScalarValue; +use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder, UInt64Builder}; +use datafusion_common::arrow::datatypes::{DataType, Float64Type}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use derive_more::Display; use s2::cellid::{CellID, MAX_LEVEL}; use s2::latlng::LatLng; use snafu::ensure; -use crate::function::{Function, FunctionContext}; -use crate::scalars::geo::helpers::{ensure_and_coerce, ensure_columns_len, ensure_columns_n}; +use crate::function::{Function, extract_args}; +use crate::scalars::geo::helpers; +use crate::scalars::geo::helpers::ensure_and_coerce; static CELL_TYPES: LazyLock> = LazyLock::new(|| vec![DataType::Int64, DataType::UInt64]); @@ -65,18 +65,23 @@ impl Function for S2LatLngToCell { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; - let lat_vec = &columns[0]; - let lon_vec = &columns[1]; + let arg0 = helpers::cast::(&arg0)?; + let lat_vec = arg0.as_primitive::(); + let arg1 = helpers::cast::(&arg1)?; + let lon_vec = arg1.as_primitive::(); let size = lat_vec.len(); - let mut results = UInt64VectorBuilder::with_capacity(size); + let mut builder = UInt64Builder::with_capacity(size); for i in 0..size { - let lat = lat_vec.get(i).as_f64_lossy(); - let lon = lon_vec.get(i).as_f64_lossy(); + let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i)); + let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i)); let result = match (lat, lon) { (Some(lat), Some(lon)) => { @@ -94,10 +99,10 @@ impl Function for S2LatLngToCell { _ => None, }; - results.push(result); + builder.append_option(result); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -119,21 +124,23 @@ impl Function for S2CellLevel { signature_of_cell() } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 1); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [cell_vec] = extract_args(self.name(), &args)?; - let cell_vec = &columns[0]; let size = cell_vec.len(); - let mut results = UInt64VectorBuilder::with_capacity(size); + let mut builder = UInt64Builder::with_capacity(size); for i in 0..size { - let cell = cell_from_value(cell_vec.get(i)); - let res = cell.map(|cell| cell.level()); + let v = ScalarValue::try_from_array(&cell_vec, i)?; + let v = cell_from_value(v).map(|x| x.level()); - results.push(res); + builder.append_option(v); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -148,28 +155,30 @@ impl Function for S2CellToToken { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { signature_of_cell() } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 1); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [cell_vec] = extract_args(self.name(), &args)?; - let cell_vec = &columns[0]; let size = cell_vec.len(); - let mut results = StringVectorBuilder::with_capacity(size); + let mut builder = StringViewBuilder::with_capacity(size); for i in 0..size { - let cell = cell_from_value(cell_vec.get(i)); - let res = cell.map(|cell| cell.to_token()); + let v = ScalarValue::try_from_array(&cell_vec, i)?; + let v = cell_from_value(v).map(|x| x.to_token()); - results.push(res.as_deref()); + builder.append_option(v.as_deref()); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -191,23 +200,28 @@ impl Function for S2CellParent { signature_of_cell_and_level() } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [cell_vec, levels] = extract_args(self.name(), &args)?; - let cell_vec = &columns[0]; - let level_vec = &columns[1]; let size = cell_vec.len(); - let mut results = UInt64VectorBuilder::with_capacity(size); + let mut builder = UInt64Builder::with_capacity(size); for i in 0..size { - let cell = cell_from_value(cell_vec.get(i)); - let level = value_to_level(level_vec.get(i))?; - let result = cell.map(|cell| cell.parent(level).0); + let cell = ScalarValue::try_from_array(&cell_vec, i).map(cell_from_value)?; + let level = ScalarValue::try_from_array(&levels, i).and_then(value_to_level)?; + let result = if let (Some(cell), Some(level)) = (cell, level) { + Some(cell.parent(level).0) + } else { + None + }; - results.push(result); + builder.append_option(result); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -233,24 +247,30 @@ fn signature_of_cell_and_level() -> Signature { Signature::one_of(signatures, Volatility::Stable) } -fn cell_from_value(v: Value) -> Option { +fn cell_from_value(v: ScalarValue) -> Option { match v { - Value::Int64(v) => Some(CellID(v as u64)), - Value::UInt64(v) => Some(CellID(v)), + ScalarValue::Int64(v) => v.map(|x| CellID(x as u64)), + ScalarValue::UInt64(v) => v.map(CellID), _ => None, } } -fn value_to_level(v: Value) -> Result { +fn value_to_level(v: ScalarValue) -> datafusion_common::Result> { match v { - Value::Int8(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i8, v as u64), - Value::Int16(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i16, v as u64), - Value::Int32(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i32, v as u64), - Value::Int64(v) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i64, v as u64), - Value::UInt8(v) => ensure_and_coerce!(v <= MAX_LEVEL as u8, v as u64), - Value::UInt16(v) => ensure_and_coerce!(v <= MAX_LEVEL as u16, v as u64), - Value::UInt32(v) => ensure_and_coerce!(v <= MAX_LEVEL as u32, v as u64), - Value::UInt64(v) => ensure_and_coerce!(v <= MAX_LEVEL, v), - _ => unreachable!(), + ScalarValue::Int8(Some(v)) => ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i8, v as u64), + ScalarValue::Int16(Some(v)) => { + ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i16, v as u64) + } + ScalarValue::Int32(Some(v)) => { + ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i32, v as u64) + } + ScalarValue::Int64(Some(v)) => { + ensure_and_coerce!(v >= 0 && v <= MAX_LEVEL as i64, v as u64) + } + ScalarValue::UInt8(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL as u8, v as u64), + ScalarValue::UInt16(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL as u16, v as u64), + ScalarValue::UInt32(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL as u32, v as u64), + ScalarValue::UInt64(Some(v)) => ensure_and_coerce!(v <= MAX_LEVEL, v), + _ => Ok(None), } } diff --git a/src/common/function/src/scalars/geo/wkt.rs b/src/common/function/src/scalars/geo/wkt.rs index b6938bf7ee..7d36e2b41f 100644 --- a/src/common/function/src/scalars/geo/wkt.rs +++ b/src/common/function/src/scalars/geo/wkt.rs @@ -12,22 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use common_error::ext::{BoxedError, PlainError}; use common_error::status_code::StatusCode; use common_query::error::{self, Result}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef}; +use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder}; +use datafusion_common::arrow::datatypes::{DataType, Float64Type}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use derive_more::Display; use geo_types::{Geometry, Point}; use snafu::ResultExt; use wkt::{ToWkt, TryFromWkt}; -use crate::function::{Function, FunctionContext}; -use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n}; +use crate::function::{Function, extract_args}; +use crate::scalars::geo::helpers; static COORDINATE_TYPES: LazyLock> = LazyLock::new(|| vec![DataType::Float32, DataType::Float64]); @@ -43,7 +42,7 @@ impl Function for LatLngToPointWkt { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { @@ -59,28 +58,33 @@ impl Function for LatLngToPointWkt { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure_columns_n!(columns, 2); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; - let lat_vec = &columns[0]; - let lng_vec = &columns[1]; + let arg0 = helpers::cast::(&arg0)?; + let lat_vec = arg0.as_primitive::(); + let arg1 = helpers::cast::(&arg1)?; + let lng_vec = arg1.as_primitive::(); let size = lat_vec.len(); - let mut results = StringVectorBuilder::with_capacity(size); + let mut builder = StringViewBuilder::with_capacity(size); for i in 0..size { - let lat = lat_vec.get(i).as_f64_lossy(); - let lng = lng_vec.get(i).as_f64_lossy(); + let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i)); + let lng = lng_vec.is_valid(i).then(|| lng_vec.value(i)); let result = match (lat, lng) { (Some(lat), Some(lng)) => Some(Point::new(lng, lat).wkt_string()), _ => None, }; - results.push(result.as_deref()); + builder.append_option(result.as_deref()); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } diff --git a/src/common/function/src/scalars/hll_count.rs b/src/common/function/src/scalars/hll_count.rs index fa6e787db0..70cd25798b 100644 --- a/src/common/function/src/scalars/hll_count.rs +++ b/src/common/function/src/scalars/hll_count.rs @@ -16,18 +16,17 @@ use std::fmt; use std::fmt::Display; +use std::sync::Arc; -use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result}; -use datafusion_expr::{Signature, Volatility}; +use common_query::error::Result; +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, AsArray, UInt64Builder}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use datatypes::arrow::datatypes::DataType; -use datatypes::prelude::Vector; -use datatypes::scalars::{ScalarVector, ScalarVectorBuilder}; -use datatypes::vectors::{BinaryVector, MutableVector, UInt64VectorBuilder, VectorRef}; use hyperloglogplus::HyperLogLog; -use snafu::OptionExt; use crate::aggrs::approximate::hll::HllStateType; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; use crate::function_registry::FunctionRegistry; const NAME: &str = "hll_count"; @@ -67,28 +66,27 @@ impl Function for HllCalcFunction { Signature::exact(vec![DataType::Binary], Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - if columns.len() != 1 { - return InvalidFuncArgsSnafu { - err_msg: format!("hll_count expects 1 argument, got {}", columns.len()), - } - .fail(); - } + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0] = extract_args(self.name(), &args)?; - let hll_vec = columns[0] - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!("expect BinaryVector, got {}", columns[0].vector_type_name()), - })?; + let Some(hll_vec) = arg0.as_binary_opt::() else { + return Err(DataFusionError::Execution(format!( + "'{}' expects argument to be Binary datatype, got {}", + self.name(), + arg0.data_type() + ))); + }; let len = hll_vec.len(); - let mut builder = UInt64VectorBuilder::with_capacity(len); + let mut builder = UInt64Builder::with_capacity(len); for i in 0..len { - let hll_opt = hll_vec.get_data(i); + let hll_opt = hll_vec.is_valid(i).then(|| hll_vec.value(i)); if hll_opt.is_none() { - builder.push_null(); + builder.append_null(); continue; } @@ -99,15 +97,15 @@ impl Function for HllCalcFunction { Ok(h) => h, Err(e) => { common_telemetry::trace!("Failed to deserialize HyperLogLogPlus: {}", e); - builder.push_null(); + builder.append_null(); continue; } }; - builder.push(Some(hll.count().round() as u64)); + builder.append_value(hll.count().round() as u64); } - Ok(builder.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -115,7 +113,9 @@ impl Function for HllCalcFunction { mod tests { use std::sync::Arc; - use datatypes::vectors::BinaryVector; + use arrow_schema::Field; + use datafusion_common::arrow::array::BinaryArray; + use datafusion_common::arrow::datatypes::UInt64Type; use super::*; use crate::utils::FixedRandomState; @@ -136,17 +136,27 @@ mod tests { } let serialized_bytes = bincode::serialize(&hll).unwrap(); - let args: Vec = vec![Arc::new(BinaryVector::from(vec![Some(serialized_bytes)]))]; + let args = vec![ColumnarValue::Array(Arc::new(BinaryArray::from_iter( + vec![Some(serialized_bytes)], + )))]; - let result = function.eval(&FunctionContext::default(), &args).unwrap(); + let result = function + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options: Arc::new(Default::default()), + }) + .unwrap(); + let ColumnarValue::Array(result) = result else { + unreachable!() + }; + let result = result.as_primitive::(); assert_eq!(result.len(), 1); // Test cardinality estimate - if let datatypes::value::Value::UInt64(v) = result.get(0) { - assert_eq!(v, 10); - } else { - panic!("Expected uint64 value"); - } + assert_eq!(result.value(0), 10); } #[test] @@ -154,20 +164,38 @@ mod tests { let function = HllCalcFunction; // Test with invalid number of arguments - let args: Vec = vec![]; - let result = function.eval(&FunctionContext::default(), &args); + let result = function.invoke_with_args(ScalarFunctionArgs { + args: vec![], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options: Arc::new(Default::default()), + }); assert!(result.is_err()); assert!( result .unwrap_err() .to_string() - .contains("hll_count expects 1 argument") + .contains("Execution error: hll_count function requires 1 argument, got 0") ); // Test with invalid binary data - let args: Vec = vec![Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]))]; // Invalid binary data - let result = function.eval(&FunctionContext::default(), &args).unwrap(); + let result = function + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(BinaryArray::from_iter( + vec![Some(vec![1, 2, 3])], + )))], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options: Arc::new(Default::default()), + }) + .unwrap(); + let ColumnarValue::Array(result) = result else { + unreachable!() + }; + let result = result.as_primitive::(); assert_eq!(result.len(), 1); - assert!(matches!(result.get(0), datatypes::value::Value::Null)); + assert!(result.is_null(0)); } } diff --git a/src/common/function/src/scalars/matches.rs b/src/common/function/src/scalars/matches.rs index e8b87943aa..1f0ed5f9b5 100644 --- a/src/common/function/src/scalars/matches.rs +++ b/src/common/function/src/scalars/matches.rs @@ -23,13 +23,13 @@ use datafusion::common::{DFSchema, Result as DfResult}; use datafusion::execution::SessionStateBuilder; use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility}; use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; -use datafusion_common::{DataFusionError, utils}; +use datafusion_common::DataFusionError; use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::array::RecordBatch; use datatypes::arrow::datatypes::{DataType, Field}; use snafu::{OptionExt, ensure}; -use crate::function::Function; +use crate::function::{Function, extract_args}; use crate::function_registry::FunctionRegistry; /// `matches` for full text search. @@ -65,8 +65,7 @@ impl Function for MatchesFunction { // TODO: read case-sensitive config fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [data_column, patterns] = utils::take_function_args(self.name(), args)?; + let [data_column, patterns] = extract_args(self.name(), &args)?; if data_column.is_empty() { return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from( diff --git a/src/common/function/src/scalars/uddsketch_calc.rs b/src/common/function/src/scalars/uddsketch_calc.rs index 2713d8ceaa..f802f25543 100644 --- a/src/common/function/src/scalars/uddsketch_calc.rs +++ b/src/common/function/src/scalars/uddsketch_calc.rs @@ -16,17 +16,16 @@ use std::fmt; use std::fmt::Display; +use std::sync::Arc; -use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result}; -use datafusion_expr::{Signature, Volatility}; -use datatypes::arrow::datatypes::DataType; -use datatypes::prelude::Vector; -use datatypes::scalars::{ScalarVector, ScalarVectorBuilder}; -use datatypes::vectors::{BinaryVector, Float64VectorBuilder, MutableVector, VectorRef}; -use snafu::OptionExt; +use common_query::error::Result; +use datafusion_common::DataFusionError; +use datafusion_common::arrow::array::{Array, AsArray, Float64Builder}; +use datafusion_common::arrow::datatypes::{DataType, Float64Type}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use uddsketch::UDDSketch; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, extract_args}; use crate::function_registry::FunctionRegistry; const NAME: &str = "uddsketch_calc"; @@ -71,30 +70,35 @@ impl Function for UddSketchCalcFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - if columns.len() != 2 { - return InvalidFuncArgsSnafu { - err_msg: format!("uddsketch_calc expects 2 arguments, got {}", columns.len()), - } - .fail(); - } + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [arg0, arg1] = extract_args(self.name(), &args)?; - let perc_vec = &columns[0]; - let sketch_vec = columns[1] - .as_any() - .downcast_ref::() - .with_context(|| DowncastVectorSnafu { - err_msg: format!("expect BinaryVector, got {}", columns[1].vector_type_name()), - })?; + let Some(percentages) = arg0.as_primitive_opt::() else { + return Err(DataFusionError::Execution(format!( + "'{}' expects 1st argument to be Float64 datatype, got {}", + self.name(), + arg0.data_type() + ))); + }; + let Some(sketch_vec) = arg1.as_binary_opt::() else { + return Err(DataFusionError::Execution(format!( + "'{}' expects 2nd argument to be Binary datatype, got {}", + self.name(), + arg1.data_type() + ))); + }; let len = sketch_vec.len(); - let mut builder = Float64VectorBuilder::with_capacity(len); + let mut builder = Float64Builder::with_capacity(len); for i in 0..len { - let perc_opt = perc_vec.get(i).as_f64_lossy(); - let sketch_opt = sketch_vec.get_data(i); + let perc_opt = percentages.is_valid(i).then(|| percentages.value(i)); + let sketch_opt = sketch_vec.is_valid(i).then(|| sketch_vec.value(i)); if sketch_opt.is_none() || perc_opt.is_none() { - builder.push_null(); + builder.append_null(); continue; } @@ -106,7 +110,7 @@ impl Function for UddSketchCalcFunction { Ok(s) => s, Err(e) => { common_telemetry::trace!("Failed to deserialize UDDSketch: {}", e); - builder.push_null(); + builder.append_null(); continue; } }; @@ -115,15 +119,15 @@ impl Function for UddSketchCalcFunction { // This is important to avoid panics when calling estimate_quantile on an empty sketch // In practice, this will happen if input is all null if sketch.bucket_iter().count() == 0 { - builder.push_null(); + builder.append_null(); continue; } // Compute the estimated quantile from the sketch let result = sketch.estimate_quantile(perc); - builder.push(Some(result)); + builder.append_value(result); } - Ok(builder.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -131,7 +135,8 @@ impl Function for UddSketchCalcFunction { mod tests { use std::sync::Arc; - use datatypes::vectors::{BinaryVector, Float64Vector}; + use arrow_schema::Field; + use datafusion_common::arrow::array::{BinaryArray, Float64Array}; use super::*; @@ -165,26 +170,32 @@ mod tests { let serialized = bincode::serialize(&sketch).unwrap(); let percentiles = vec![0.5, 0.9, 0.95]; - let args: Vec = vec![ - Arc::new(Float64Vector::from_vec(percentiles.clone())), - Arc::new(BinaryVector::from(vec![Some(serialized.clone()); 3])), + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(percentiles.clone()))), + ColumnarValue::Array(Arc::new(BinaryArray::from_iter_values(vec![serialized; 3]))), ]; - let result = function.eval(&FunctionContext::default(), &args).unwrap(); + let result = function + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Float64, false)), + config_options: Arc::new(Default::default()), + }) + .unwrap(); + let ColumnarValue::Array(result) = result else { + unreachable!() + }; + let result = result.as_primitive::(); assert_eq!(result.len(), 3); // Test median (p50) - assert!( - matches!(result.get(0), datatypes::value::Value::Float64(v) if (v - expected_p50).abs() < 1e-10) - ); + assert!((result.value(0) - expected_p50).abs() < 1e-10); // Test p90 - assert!( - matches!(result.get(1), datatypes::value::Value::Float64(v) if (v - expected_p90).abs() < 1e-10) - ); + assert!((result.value(1) - expected_p90).abs() < 1e-10); // Test p95 - assert!( - matches!(result.get(2), datatypes::value::Value::Float64(v) if (v - expected_p95).abs() < 1e-10) - ); + assert!((result.value(2) - expected_p95).abs() < 1e-10); } #[test] @@ -192,23 +203,42 @@ mod tests { let function = UddSketchCalcFunction; // Test with invalid number of arguments - let args: Vec = vec![Arc::new(Float64Vector::from_vec(vec![0.95]))]; - let result = function.eval(&FunctionContext::default(), &args); + let result = function.invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 0.95, + ])))], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("x", DataType::Float64, false)), + config_options: Arc::new(Default::default()), + }); assert!(result.is_err()); assert!( result .unwrap_err() .to_string() - .contains("uddsketch_calc expects 2 arguments") + .contains("Execution error: uddsketch_calc function requires 2 arguments, got 1") ); // Test with invalid binary data - let args: Vec = vec![ - Arc::new(Float64Vector::from_vec(vec![0.95])), - Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])])), // Invalid binary data + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![0.95]))), + ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![Some(vec![1, 2, 3])]))), ]; - let result = function.eval(&FunctionContext::default(), &args).unwrap(); + let result = function + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("x", DataType::Float64, false)), + config_options: Arc::new(Default::default()), + }) + .unwrap(); + let ColumnarValue::Array(result) = result else { + unreachable!() + }; + let result = result.as_primitive::(); assert_eq!(result.len(), 1); - assert!(matches!(result.get(0), datatypes::value::Value::Null)); + assert!(result.is_null(0)); } } diff --git a/src/common/function/src/scalars/vector/vector_subvector.rs b/src/common/function/src/scalars/vector/vector_subvector.rs index 239edaaa93..ef02151a07 100644 --- a/src/common/function/src/scalars/vector/vector_subvector.rs +++ b/src/common/function/src/scalars/vector/vector_subvector.rs @@ -19,12 +19,12 @@ use common_query::error::{InvalidFuncArgsSnafu, Result}; use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder}; use datafusion::arrow::datatypes::Int64Type; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{ScalarValue, utils}; +use datafusion_common::ScalarValue; use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use datatypes::arrow::datatypes::DataType; use snafu::ensure; -use crate::function::Function; +use crate::function::{Function, extract_args}; use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; const NAME: &str = "vec_subvector"; @@ -71,8 +71,7 @@ impl Function for VectorSubvectorFunction { &self, args: ScalarFunctionArgs, ) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?; + let [arg0, arg1, arg2] = extract_args(self.name(), &args)?; let arg1 = arg1.as_primitive::(); let arg2 = arg2.as_primitive::(); diff --git a/src/common/function/src/system/database.rs b/src/common/function/src/system/database.rs index d28975afb4..97e4ee43ab 100644 --- a/src/common/function/src/system/database.rs +++ b/src/common/function/src/system/database.rs @@ -13,16 +13,14 @@ // limitations under the License. use std::fmt::{self}; -use std::sync::Arc; use common_query::error::Result; use datafusion::arrow::datatypes::DataType; -use datafusion_expr::{Signature, Volatility}; -use datatypes::prelude::ScalarVector; -use datatypes::vectors::{StringVector, UInt32Vector, VectorRef}; +use datafusion_common::ScalarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; use derive_more::Display; -use crate::function::{Function, FunctionContext}; +use crate::function::{Function, find_function_context}; /// A function to return current schema name. #[derive(Clone, Debug, Default)] @@ -55,17 +53,21 @@ impl Function for DatabaseFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let db = func_ctx.query_ctx.current_schema(); - Ok(Arc::new(StringVector::from_slice(&[&db])) as _) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(db)))) } } @@ -77,17 +79,21 @@ impl Function for CurrentSchemaFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let db = func_ctx.query_ctx.current_schema(); - Ok(Arc::new(StringVector::from_slice(&[&db])) as _) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(db)))) } } @@ -97,17 +103,23 @@ impl Function for SessionUserFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let user = func_ctx.query_ctx.current_user(); - Ok(Arc::new(StringVector::from_slice(&[user.username()])) as _) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + user.username().to_string(), + )))) } } @@ -117,17 +129,23 @@ impl Function for ReadPreferenceFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(DataType::Utf8View) } fn signature(&self) -> Signature { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let read_preference = func_ctx.query_ctx.read_preference(); - Ok(Arc::new(StringVector::from_slice(&[read_preference.as_ref()])) as _) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + read_preference.to_string(), + )))) } } @@ -144,10 +162,14 @@ impl Function for PgBackendPidFunction { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let pid = func_ctx.query_ctx.process_id(); - Ok(Arc::new(UInt32Vector::from_slice([pid])) as _) + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(pid as u64)))) } } @@ -157,17 +179,21 @@ impl Function for ConnectionIdFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::UInt64) + Ok(DataType::UInt32) } fn signature(&self) -> Signature { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let func_ctx = find_function_context(&args)?; let pid = func_ctx.query_ctx.process_id(); - Ok(Arc::new(UInt32Vector::from_slice([pid])) as _) + Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(pid)))) } } @@ -199,14 +225,17 @@ impl fmt::Display for ReadPreferenceFunction { mod tests { use std::sync::Arc; + use arrow_schema::Field; + use datafusion_common::config::ConfigOptions; use session::context::QueryContextBuilder; use super::*; + use crate::function::FunctionContext; #[test] fn test_build_function() { let build = DatabaseFunction; assert_eq!("database", build.name()); - assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap()); + assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap()); assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable)); let query_ctx = QueryContextBuilder::default() @@ -214,12 +243,24 @@ mod tests { .build() .into(); - let func_ctx = FunctionContext { + let mut config_options = ConfigOptions::default(); + config_options.extensions.insert(FunctionContext { query_ctx, ..Default::default() + }); + let config_options = Arc::new(config_options); + + let args = ScalarFunctionArgs { + args: vec![], + arg_fields: vec![], + number_rows: 0, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options, }; - let vector = build.eval(&func_ctx, &[]).unwrap(); - let expect: VectorRef = Arc::new(StringVector::from(vec!["test_db"])); - assert_eq!(expect, vector); + let result = build.invoke_with_args(args).unwrap(); + let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) = result else { + unreachable!() + }; + assert_eq!(s, "test_db"); } } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 96e1f60763..61430d6786 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -25,6 +25,7 @@ use async_trait::async_trait; use common_base::Plugins; use common_catalog::consts::is_readonly_schema; use common_error::ext::BoxedError; +use common_function::function::FunctionContext; use common_function::function_factory::ScalarFunctionFactory; use common_query::{Output, OutputData, OutputMeta}; use common_recordbatch::adapter::RecordBatchStreamAdapter; @@ -568,6 +569,16 @@ impl QueryEngine for DatafusionQueryEngine { }); } } + + state + .config_mut() + .options_mut() + .extensions + .insert(FunctionContext { + query_ctx: query_ctx.clone(), + state: self.engine_state().function_state(), + }); + QueryEngineContext::new(state, query_ctx) } diff --git a/src/query/src/dist_plan/analyzer.rs b/src/query/src/dist_plan/analyzer.rs index 424500f3a1..50dfd6c56e 100644 --- a/src/query/src/dist_plan/analyzer.rs +++ b/src/query/src/dist_plan/analyzer.rs @@ -15,11 +15,13 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use std::sync::Arc; +use chrono::{DateTime, Utc}; use common_telemetry::debug; use datafusion::config::{ConfigExtension, ExtensionOptions}; use datafusion::datasource::DefaultTableSource; use datafusion::error::Result as DfResult; use datafusion_common::Column; +use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_expr::expr::{Exists, InSubquery}; @@ -27,7 +29,7 @@ use datafusion_expr::utils::expr_to_columns; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Subquery, col as col_fn}; use datafusion_optimizer::analyzer::AnalyzerRule; use datafusion_optimizer::simplify_expressions::SimplifyExpressions; -use datafusion_optimizer::{OptimizerContext, OptimizerRule}; +use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; use table::metadata::TableType; use table::table::adapter::DfTableProviderAdapter; @@ -100,8 +102,42 @@ impl AnalyzerRule for DistPlannerAnalyzer { plan: LogicalPlan, config: &ConfigOptions, ) -> datafusion_common::Result { - // preprocess the input plan - let optimizer_context = OptimizerContext::new(); + let mut config = config.clone(); + // Aligned with the behavior in `datafusion_optimizer::OptimizerContext::new()`. + config.optimizer.filter_null_join_keys = true; + let config = Arc::new(config); + + // The `ConstEvaluator` in `SimplifyExpressions` might evaluate some UDFs early in the + // planning stage, by executing them directly. For example, the `database()` function. + // So the `ConfigOptions` here (which is set from the session context) should be present + // in the UDF's `ScalarFunctionArgs`. However, the default implementation in DataFusion + // seems to lost track on it: the `ConfigOptions` is recreated with its default values again. + // So we create a custom `OptimizerConfig` with the desired `ConfigOptions` + // to walk around the issue. + struct OptimizerContext { + inner: datafusion_optimizer::OptimizerContext, + config: Arc, + } + + impl OptimizerConfig for OptimizerContext { + fn query_execution_start_time(&self) -> DateTime { + self.inner.query_execution_start_time() + } + + fn alias_generator(&self) -> &Arc { + self.inner.alias_generator() + } + + fn options(&self) -> Arc { + self.config.clone() + } + } + + let optimizer_context = OptimizerContext { + inner: datafusion_optimizer::OptimizerContext::new(), + config: config.clone(), + }; + let plan = SimplifyExpressions::new() .rewrite(plan, &optimizer_context)? .data;