diff --git a/src/common/function/src/scalars/date/date_add.rs b/src/common/function/src/scalars/date/date_add.rs index 9785b92a91..973535fc7b 100644 --- a/src/common/function/src/scalars/date/date_add.rs +++ b/src/common/function/src/scalars/date/date_add.rs @@ -14,14 +14,15 @@ use std::fmt; -use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; +use common_query::error::{ArrowComputeSnafu, Result}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::utils; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::compute::kernels::numeric; use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datatypes::vectors::{Helper, VectorRef}; -use snafu::{ResultExt, ensure}; +use snafu::ResultExt; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; /// A function adds an interval value to Timestamp, Date, and return the result. @@ -58,25 +59,15 @@ impl Function for DateAddFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect 2, have: {}", - columns.len() - ), - } - ); - - let left = columns[0].to_arrow_array(); - let right = columns[1].to_arrow_array(); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [left, right] = utils::take_function_args(self.name(), args)?; let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?; - let arrow_type = result.data_type().clone(); - Helper::try_into_vector(result).context(IntoVectorSnafu { - data_type: arrow_type, - }) + Ok(ColumnarValue::Array(result)) } } @@ -90,12 +81,14 @@ impl fmt::Display for DateAddFunction { mod tests { use std::sync::Arc; - use datafusion_expr::{TypeSignature, Volatility}; - use datatypes::arrow::datatypes::IntervalDayTime; - use datatypes::value::Value; - use datatypes::vectors::{ - DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector, + use arrow_schema::Field; + use datafusion::arrow::array::{ + Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray, + TimestampSecondArray, }; + use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType}; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{TypeSignature, Volatility}; use super::{DateAddFunction, *}; @@ -142,25 +135,37 @@ mod tests { ]; let results = [Some(124), None, Some(45), None]; - let time_vector = TimestampSecondVector::from(times.clone()); - let interval_vector = IntervalDayTimeVector::from_vec(intervals); - let args: Vec = vec![Arc::new(time_vector), Arc::new(interval_vector)]; - let vector = f.eval(&FunctionContext::default(), &args).unwrap(); + let args = vec![ + ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))), + ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))), + ]; + + let vector = f + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new( + "x", + DataType::Timestamp(TimeUnit::Second, None), + true, + )), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v])) + .map(|mut a| a.remove(0)) + .unwrap(); + let vector = vector.as_primitive::(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { - let v = vector.get(i); let result = results.get(i).unwrap(); - if result.is_none() { - assert_eq!(Value::Null, v); - continue; - } - match v { - Value::Timestamp(ts) => { - assert_eq!(ts.value(), result.unwrap()); - } - _ => unreachable!(), + if let Some(x) = result { + assert!(vector.is_valid(i)); + assert_eq!(vector.value(i), *x); + } else { + assert!(vector.is_null(i)); } } } @@ -174,25 +179,37 @@ mod tests { let intervals = vec![1, 2, 3, 1]; let results = [Some(154), None, Some(131), None]; - let date_vector = DateVector::from(dates.clone()); - let interval_vector = IntervalYearMonthVector::from_vec(intervals); - let args: Vec = vec![Arc::new(date_vector), Arc::new(interval_vector)]; - let vector = f.eval(&FunctionContext::default(), &args).unwrap(); + let args = vec![ + ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))), + ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))), + ]; + + let vector = f + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new( + "x", + DataType::Timestamp(TimeUnit::Second, None), + true, + )), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v])) + .map(|mut a| a.remove(0)) + .unwrap(); + let vector = vector.as_primitive::(); assert_eq!(4, vector.len()); for (i, _t) in dates.iter().enumerate() { - let v = vector.get(i); let result = results.get(i).unwrap(); - if result.is_none() { - assert_eq!(Value::Null, v); - continue; - } - match v { - Value::Date(date) => { - assert_eq!(date.val(), result.unwrap()); - } - _ => unreachable!(), + if let Some(x) = result { + assert!(vector.is_valid(i)); + assert_eq!(vector.value(i), *x); + } else { + assert!(vector.is_null(i)); } } } diff --git a/src/common/function/src/scalars/date/date_sub.rs b/src/common/function/src/scalars/date/date_sub.rs index e451ee3c6e..6ed5b84c90 100644 --- a/src/common/function/src/scalars/date/date_sub.rs +++ b/src/common/function/src/scalars/date/date_sub.rs @@ -14,14 +14,15 @@ use std::fmt; -use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; +use common_query::error::{ArrowComputeSnafu, Result}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::utils; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::compute::kernels::numeric; use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datatypes::vectors::{Helper, VectorRef}; -use snafu::{ResultExt, ensure}; +use snafu::ResultExt; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; /// A function subtracts an interval value to Timestamp, Date, and return the result. @@ -58,25 +59,15 @@ impl Function for DateSubFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect 2, have: {}", - columns.len() - ), - } - ); - - let left = columns[0].to_arrow_array(); - let right = columns[1].to_arrow_array(); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [left, right] = utils::take_function_args(self.name(), args)?; let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?; - let arrow_type = result.data_type().clone(); - Helper::try_into_vector(result).context(IntoVectorSnafu { - data_type: arrow_type, - }) + Ok(ColumnarValue::Array(result)) } } @@ -90,12 +81,14 @@ impl fmt::Display for DateSubFunction { mod tests { use std::sync::Arc; - use datafusion_expr::{TypeSignature, Volatility}; - use datatypes::arrow::datatypes::IntervalDayTime; - use datatypes::value::Value; - use datatypes::vectors::{ - DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector, + use arrow_schema::Field; + use datafusion::arrow::array::{ + Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray, + TimestampSecondArray, }; + use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType}; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{TypeSignature, Volatility}; use super::{DateSubFunction, *}; @@ -142,25 +135,37 @@ mod tests { ]; let results = [Some(122), None, Some(39), None]; - let time_vector = TimestampSecondVector::from(times.clone()); - let interval_vector = IntervalDayTimeVector::from_vec(intervals); - let args: Vec = vec![Arc::new(time_vector), Arc::new(interval_vector)]; - let vector = f.eval(&FunctionContext::default(), &args).unwrap(); + let args = vec![ + ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))), + ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))), + ]; + + let vector = f + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new( + "x", + DataType::Timestamp(TimeUnit::Second, None), + true, + )), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v])) + .map(|mut a| a.remove(0)) + .unwrap(); + let vector = vector.as_primitive::(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { - let v = vector.get(i); let result = results.get(i).unwrap(); - if result.is_none() { - assert_eq!(Value::Null, v); - continue; - } - match v { - Value::Timestamp(ts) => { - assert_eq!(ts.value(), result.unwrap()); - } - _ => unreachable!(), + if let Some(x) = result { + assert!(vector.is_valid(i)); + assert_eq!(vector.value(i), *x); + } else { + assert!(vector.is_null(i)); } } } @@ -180,25 +185,37 @@ mod tests { let intervals = vec![1, 2, 3, 1]; let results = [Some(3659), None, Some(1168), None]; - let date_vector = DateVector::from(dates.clone()); - let interval_vector = IntervalYearMonthVector::from_vec(intervals); - let args: Vec = vec![Arc::new(date_vector), Arc::new(interval_vector)]; - let vector = f.eval(&FunctionContext::default(), &args).unwrap(); + let args = vec![ + ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))), + ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))), + ]; + + let vector = f + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new( + "x", + DataType::Timestamp(TimeUnit::Second, None), + true, + )), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v])) + .map(|mut a| a.remove(0)) + .unwrap(); + let vector = vector.as_primitive::(); assert_eq!(4, vector.len()); for (i, _t) in dates.iter().enumerate() { - let v = vector.get(i); let result = results.get(i).unwrap(); - if result.is_none() { - assert_eq!(Value::Null, v); - continue; - } - match v { - Value::Date(date) => { - assert_eq!(date.val(), result.unwrap()); - } - _ => unreachable!(), + if let Some(x) = result { + assert!(vector.is_valid(i)); + assert_eq!(vector.value(i), *x); + } else { + assert!(vector.is_null(i)); } } } diff --git a/src/common/function/src/scalars/geo/geohash.rs b/src/common/function/src/scalars/geo/geohash.rs index 59e3444ffb..43085b6b7e 100644 --- a/src/common/function/src/scalars/geo/geohash.rs +++ b/src/common/function/src/scalars/geo/geohash.rs @@ -17,62 +17,26 @@ use std::sync::Arc; use common_error::ext::{BoxedError, PlainError}; use common_error::status_code::StatusCode; -use common_query::error::{self, InvalidFuncArgsSnafu, Result}; -use datafusion::arrow::datatypes::Field; +use common_query::error::{self, Result}; +use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder}; +use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, utils}; use datafusion_expr::type_coercion::aggregates::INTEGERS; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use datatypes::arrow::datatypes::DataType; -use datatypes::prelude::ConcreteDataType; -use datatypes::scalars::{Scalar, ScalarVectorBuilder}; -use datatypes::value::{ListValue, Value}; -use datatypes::vectors::{ListVectorBuilder, MutableVector, StringVectorBuilder, VectorRef}; +use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use geohash::Coord; -use snafu::{ResultExt, ensure}; +use snafu::ResultExt; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; +use crate::scalars::geo::helpers; -macro_rules! ensure_resolution_usize { - ($v: ident) => { - if !($v > 0 && $v <= 12) { - Err(BoxedError::new(PlainError::new( - format!("Invalid geohash resolution {}, expect value: [1, 12]", $v), - StatusCode::EngineExecuteQuery, - ))) - .context(error::ExecuteSnafu) - } else { - Ok($v as usize) - } - }; -} - -fn try_into_resolution(v: Value) -> Result { - match v { - Value::Int8(v) => { - ensure_resolution_usize!(v) - } - Value::Int16(v) => { - ensure_resolution_usize!(v) - } - Value::Int32(v) => { - ensure_resolution_usize!(v) - } - Value::Int64(v) => { - ensure_resolution_usize!(v) - } - Value::UInt8(v) => { - ensure_resolution_usize!(v) - } - Value::UInt16(v) => { - ensure_resolution_usize!(v) - } - Value::UInt32(v) => { - ensure_resolution_usize!(v) - } - Value::UInt64(v) => { - ensure_resolution_usize!(v) - } - _ => unreachable!(), +fn ensure_resolution_usize(v: u8) -> datafusion_common::Result { + if v == 0 || v > 12 { + return Err(DataFusionError::Execution(format!( + "Invalid geohash resolution {v}, valid value range: [1, 12]" + ))); } + Ok(v as usize) } /// Function that return geohash string for a given geospatial coordinate. @@ -109,31 +73,33 @@ impl Function for GeohashFunction { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 3, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect 3, provided : {}", - columns.len() - ), - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?; - let lat_vec = &columns[0]; - let lon_vec = &columns[1]; - let resolution_vec = &columns[2]; + let lat_vec = helpers::cast::(&lat_vec)?; + let lat_vec = lat_vec.as_primitive::(); + let lon_vec = helpers::cast::(&lon_vec)?; + let lon_vec = lon_vec.as_primitive::(); + let resolutions = helpers::cast::(&resolutions)?; + let resolutions = resolutions.as_primitive::(); let size = lat_vec.len(); - let mut results = StringVectorBuilder::with_capacity(size); + let mut builder = StringViewBuilder::with_capacity(size); for i in 0..size { - let lat = lat_vec.get(i).as_f64_lossy(); - let lon = lon_vec.get(i).as_f64_lossy(); - let r = try_into_resolution(resolution_vec.get(i))?; + let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i)); + let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i)); + let r = resolutions + .is_valid(i) + .then(|| ensure_resolution_usize(resolutions.value(i))) + .transpose()?; - let result = match (lat, lon) { - (Some(lat), Some(lon)) => { + let result = match (lat, lon, r) { + (Some(lat), Some(lon), Some(r)) => { let coord = Coord { x: lon, y: lat }; let encoded = geohash::encode(coord, r) .map_err(|e| { @@ -148,10 +114,10 @@ impl Function for GeohashFunction { _ => None, }; - results.push(result.as_deref()); + builder.append_option(result); } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -176,8 +142,8 @@ impl Function for GeohashNeighboursFunction { fn return_type(&self, _: &[DataType]) -> Result { Ok(DataType::List(Arc::new(Field::new( - "x", - DataType::Utf8, + "item", + DataType::Utf8View, false, )))) } @@ -199,32 +165,33 @@ impl Function for GeohashNeighboursFunction { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 3, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect 3, provided : {}", - columns.len() - ), - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?; - let lat_vec = &columns[0]; - let lon_vec = &columns[1]; - let resolution_vec = &columns[2]; + let lat_vec = helpers::cast::(&lat_vec)?; + let lat_vec = lat_vec.as_primitive::(); + let lon_vec = helpers::cast::(&lon_vec)?; + let lon_vec = lon_vec.as_primitive::(); + let resolutions = helpers::cast::(&resolutions)?; + let resolutions = resolutions.as_primitive::(); let size = lat_vec.len(); - let mut results = - ListVectorBuilder::with_type_capacity(ConcreteDataType::string_datatype(), size); + let mut builder = ListBuilder::new(StringViewBuilder::new()); for i in 0..size { - let lat = lat_vec.get(i).as_f64_lossy(); - let lon = lon_vec.get(i).as_f64_lossy(); - let r = try_into_resolution(resolution_vec.get(i))?; + let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i)); + let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i)); + let r = resolutions + .is_valid(i) + .then(|| ensure_resolution_usize(resolutions.value(i))) + .transpose()?; - let result = match (lat, lon) { - (Some(lat), Some(lon)) => { + match (lat, lon, r) { + (Some(lat), Some(lon), Some(r)) => { let coord = Coord { x: lon, y: lat }; let encoded = geohash::encode(coord, r) .map_err(|e| { @@ -242,8 +209,8 @@ impl Function for GeohashNeighboursFunction { )) }) .context(error::ExecuteSnafu)?; - Some(ListValue::new( - vec![ + builder.append_value( + [ neighbours.n, neighbours.nw, neighbours.w, @@ -254,22 +221,14 @@ impl Function for GeohashNeighboursFunction { neighbours.ne, ] .into_iter() - .map(Value::from) - .collect(), - ConcreteDataType::string_datatype(), - )) + .map(Some), + ); } - _ => None, + _ => builder.append_null(), }; - - if let Some(list_value) = result { - results.push(Some(list_value.as_scalar_ref())); - } else { - results.push(None); - } } - Ok(results.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } diff --git a/src/common/function/src/scalars/geo/h3.rs b/src/common/function/src/scalars/geo/h3.rs index 66d92863ca..0014e3a3cb 100644 --- a/src/common/function/src/scalars/geo/h3.rs +++ b/src/common/function/src/scalars/geo/h3.rs @@ -19,13 +19,11 @@ use common_error::ext::{BoxedError, PlainError}; use common_error::status_code::StatusCode; use common_query::error::{self, Result}; use datafusion::arrow::array::{ - Array, ArrayRef, AsArray, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder, - StringViewArray, StringViewBuilder, UInt8Builder, UInt64Builder, + Array, AsArray, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder, StringViewArray, + StringViewBuilder, UInt8Builder, UInt64Builder, }; use datafusion::arrow::compute; -use datafusion::arrow::datatypes::{ - ArrowPrimitiveType, Float64Type, Int64Type, UInt8Type, UInt64Type, -}; +use datafusion::arrow::datatypes::{Float64Type, Int64Type, UInt8Type, UInt64Type}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{DataFusionError, ScalarValue, utils}; use datafusion_expr::type_coercion::aggregates::INTEGERS; @@ -36,6 +34,7 @@ use h3o::{CellIndex, LatLng, Resolution}; use snafu::prelude::*; use crate::function::Function; +use crate::scalars::geo::helpers; static CELL_TYPES: LazyLock> = LazyLock::new(|| vec![DataType::Int64, DataType::UInt64, DataType::Utf8]); @@ -89,11 +88,11 @@ impl Function for H3LatLngToCell { let args = ColumnarValue::values_to_arrays(&args.args)?; let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?; - let lat_vec = cast::(&lat_vec)?; + let lat_vec = helpers::cast::(&lat_vec)?; let lat_vec = lat_vec.as_primitive::(); - let lon_vec = cast::(&lon_vec)?; + let lon_vec = helpers::cast::(&lon_vec)?; let lon_vec = lon_vec.as_primitive::(); - let resolutions = cast::(&resolution_vec)?; + let resolutions = helpers::cast::(&resolution_vec)?; let resolution_vec = resolutions.as_primitive::(); let size = lat_vec.len(); @@ -171,11 +170,11 @@ impl Function for H3LatLngToCellString { let args = ColumnarValue::values_to_arrays(&args.args)?; let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?; - let lat_vec = cast::(&lat_vec)?; + let lat_vec = helpers::cast::(&lat_vec)?; let lat_vec = lat_vec.as_primitive::(); - let lon_vec = cast::(&lon_vec)?; + let lon_vec = helpers::cast::(&lon_vec)?; let lon_vec = lon_vec.as_primitive::(); - let resolutions = cast::(&resolution_vec)?; + let resolutions = helpers::cast::(&resolution_vec)?; let resolution_vec = resolutions.as_primitive::(); let size = lat_vec.len(); @@ -547,7 +546,7 @@ impl Function for H3CellToChildren { ) -> datafusion_common::Result { let args = ColumnarValue::values_to_arrays(&args.args)?; let [cell_vec, res_vec] = utils::take_function_args(self.name(), args)?; - let resolutions = cast::(&res_vec)?; + let resolutions = helpers::cast::(&res_vec)?; let resolutions = resolutions.as_primitive::(); let size = cell_vec.len(); @@ -641,7 +640,7 @@ where { let args = ColumnarValue::values_to_arrays(&args.args)?; let [cells, resolutions] = utils::take_function_args(name, args)?; - let resolutions = cast::(&resolutions)?; + let resolutions = helpers::cast::(&resolutions)?; let resolutions = resolutions.as_primitive::(); let mut builder = UInt64Builder::with_capacity(cells.len()); @@ -698,7 +697,7 @@ impl Function for H3ChildPosToCell { ) -> datafusion_common::Result { let args = ColumnarValue::values_to_arrays(&args.args)?; let [pos_vec, cell_vec, res_vec] = utils::take_function_args(self.name(), args)?; - let resolutions = cast::(&res_vec)?; + let resolutions = helpers::cast::(&res_vec)?; let resolutions = resolutions.as_primitive::(); let size = cell_vec.len(); @@ -722,18 +721,6 @@ impl Function for H3ChildPosToCell { } } -fn cast(array: &ArrayRef) -> datafusion_common::Result { - let x = compute::cast_with_options( - array.as_ref(), - &T::DATA_TYPE, - &compute::CastOptions { - safe: false, - ..Default::default() - }, - )?; - Ok(x) -} - /// Function that returns cells with k distances of given cell #[derive(Clone, Debug, Default, Display)] #[display("{}", self.name())] diff --git a/src/common/function/src/scalars/geo/helpers.rs b/src/common/function/src/scalars/geo/helpers.rs index aba3c80543..c76e188990 100644 --- a/src/common/function/src/scalars/geo/helpers.rs +++ b/src/common/function/src/scalars/geo/helpers.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType}; +use datafusion::arrow::compute; + macro_rules! ensure_columns_len { ($columns:ident) => { snafu::ensure!( @@ -73,3 +76,15 @@ macro_rules! ensure_and_coerce { } pub(crate) use ensure_and_coerce; + +pub(crate) fn cast(array: &ArrayRef) -> datafusion_common::Result { + let x = compute::cast_with_options( + array.as_ref(), + &T::DATA_TYPE, + &compute::CastOptions { + safe: false, + ..Default::default() + }, + )?; + Ok(x) +} diff --git a/src/common/function/src/scalars/matches.rs b/src/common/function/src/scalars/matches.rs index e44c981ee8..e8b87943aa 100644 --- a/src/common/function/src/scalars/matches.rs +++ b/src/common/function/src/scalars/matches.rs @@ -16,21 +16,20 @@ use std::collections::HashMap; use std::fmt; use std::sync::Arc; -use common_query::error::{IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result}; +use common_query::error::{InvalidFuncArgsSnafu, Result}; +use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion}; use datafusion::common::{DFSchema, Result as DfResult}; use datafusion::execution::SessionStateBuilder; -use datafusion::logical_expr::{self, Expr, Volatility}; +use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility}; use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; -use datafusion_expr::Signature; +use datafusion_common::{DataFusionError, utils}; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::array::RecordBatch; use datatypes::arrow::datatypes::{DataType, Field}; -use datatypes::prelude::VectorRef; -use datatypes::vectors::BooleanVector; -use snafu::{OptionExt, ResultExt, ensure}; -use store_api::storage::ConcreteDataType; +use snafu::{OptionExt, ensure}; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::function_registry::FunctionRegistry; /// `matches` for full text search. @@ -65,38 +64,36 @@ impl Function for MatchesFunction { } // TODO: read case-sensitive config - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly 2, have: {}", - columns.len() - ), - } - ); + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [data_column, patterns] = utils::take_function_args(self.name(), args)?; - let data_column = &columns[0]; if data_column.is_empty() { - return Ok(Arc::new(BooleanVector::from(Vec::::with_capacity(0)))); + return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from( + Vec::::with_capacity(0), + )))); } - let pattern_vector = &columns[1] - .cast(&ConcreteDataType::string_datatype()) - .context(InvalidInputTypeSnafu { - err_msg: "cannot cast `pattern` to string", - })?; // Safety: both length and type are checked before - let pattern = pattern_vector.get(0).as_string().unwrap(); + let pattern = match patterns.data_type() { + DataType::Utf8View => patterns.as_string_view().value(0), + DataType::Utf8 => patterns.as_string::().value(0), + DataType::LargeUtf8 => patterns.as_string::().value(0), + t => { + return Err(DataFusionError::Execution(format!( + "unsupported datatype {t}" + ))); + } + }; self.eval(data_column, pattern) } } impl MatchesFunction { - fn eval(&self, data: &VectorRef, pattern: String) -> Result { + fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult { let col_name = "data"; let parser_context = ParserContext::default(); - let raw_ast = parser_context.parse_pattern(&pattern)?; + let raw_ast = parser_context.parse_pattern(pattern)?; let ast = raw_ast.transform_ast()?; let like_expr = ast.into_like_expr(col_name); @@ -107,19 +104,14 @@ impl MatchesFunction { let physical_expr = planner.create_physical_expr(&like_expr, &input_schema, &session_state)?; - let data_array = data.to_arrow_array(); let arrow_schema = Arc::new(input_schema.as_arrow().clone()); let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap(); let num_rows = input_record_batch.num_rows(); let result = physical_expr.evaluate(&input_record_batch)?; let result_array = result.into_array(num_rows)?; - let result_vector = - BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu { - data_type: DataType::Boolean, - })?; - Ok(Arc::new(result_vector)) + Ok(ColumnarValue::Array(Arc::new(result_array))) } fn input_schema() -> DFSchema { @@ -833,7 +825,9 @@ impl Tokenizer { #[cfg(test)] mod test { - use datatypes::vectors::StringVector; + use datafusion::arrow::array::StringArray; + use datafusion_common::ScalarValue; + use datafusion_common::config::ConfigOptions; use super::*; @@ -1300,7 +1294,7 @@ mod test { "The quick brown fox jumps over dog", "The quick brown fox jumps over the dog", ]; - let input_vector: VectorRef = Arc::new(StringVector::from(input_data)); + let col: ArrayRef = Arc::new(StringArray::from(input_data)); let cases = [ // basic cases ("quick", vec![true, false, true, true, true, true, true]), @@ -1391,9 +1385,22 @@ mod test { let f = MatchesFunction; for (pattern, expected) in cases { - let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap(); - let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _; - assert_eq!(expected, actual, "{pattern}"); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(col.clone()), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))), + ], + arg_fields: vec![], + number_rows: col.len(), + return_field: Arc::new(Field::new("x", col.data_type().clone(), true)), + config_options: Arc::new(ConfigOptions::new()), + }; + let actual = f + .invoke_with_args(args) + .and_then(|x| x.to_array(col.len())) + .unwrap(); + let expected: ArrayRef = Arc::new(BooleanArray::from(expected)); + assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}"); } } } diff --git a/src/common/function/src/scalars/math.rs b/src/common/function/src/scalars/math.rs index a3baf54add..7376779d7b 100644 --- a/src/common/function/src/scalars/math.rs +++ b/src/common/function/src/scalars/math.rs @@ -23,10 +23,9 @@ use common_query::error::Result; use datafusion::arrow::datatypes::DataType; use datafusion::error::DataFusionError; use datafusion_expr::{Signature, Volatility}; -use datatypes::vectors::VectorRef; pub use rate::RateFunction; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::function_registry::FunctionRegistry; use crate::scalars::math::modulo::ModuloFunction; @@ -75,11 +74,4 @@ impl Function for RangeFunction { fn signature(&self) -> Signature { Signature::variadic_any(Volatility::Immutable) } - - fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { - Err(DataFusionError::Internal( - "range_fn just a empty function used in range select, It should not be eval!".into(), - ) - .into()) - } } diff --git a/src/common/function/src/scalars/math/clamp.rs b/src/common/function/src/scalars/math/clamp.rs index 37b17d231f..2f55d0208b 100644 --- a/src/common/function/src/scalars/math/clamp.rs +++ b/src/common/function/src/scalars/math/clamp.rs @@ -15,54 +15,21 @@ use std::fmt::{self, Display}; use std::sync::Arc; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion::arrow::array::{ArrayIter, PrimitiveArray}; +use common_query::error::Result; +use datafusion::arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray}; use datafusion::arrow::datatypes::DataType as ArrowDataType; -use datafusion::logical_expr::Volatility; -use datafusion_expr::Signature; +use datafusion::logical_expr::{ColumnarValue, Volatility}; +use datafusion_common::{DataFusionError, ScalarValue, utils}; use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datatypes::data_type::DataType; -use datatypes::prelude::VectorRef; -use datatypes::types::LogicalPrimitiveType; -use datatypes::value::TryAsPrimitive; -use datatypes::vectors::PrimitiveVector; -use datatypes::with_match_primitive_type_id; -use snafu::{OptionExt, ensure}; +use datafusion_expr::{ScalarFunctionArgs, Signature}; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; #[derive(Clone, Debug, Default)] pub struct ClampFunction; const CLAMP_NAME: &str = "clamp"; -/// Ensure the vector is constant and not empty (i.e., all values are identical) -fn ensure_constant_vector(vector: &VectorRef) -> Result<()> { - ensure!( - !vector.is_empty(), - InvalidFuncArgsSnafu { - err_msg: "Expect at least one value", - } - ); - - if vector.is_const() { - return Ok(()); - } - - let first = vector.get_ref(0); - for i in 1..vector.len() { - let v = vector.get_ref(i); - if first != v { - return InvalidFuncArgsSnafu { - err_msg: "All values in min/max argument must be identical", - } - .fail(); - } - } - - Ok(()) -} - impl Function for ClampFunction { fn name(&self) -> &str { CLAMP_NAME @@ -78,76 +45,12 @@ impl Function for ClampFunction { Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 3, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly 3, have: {}", - columns.len() - ), - } - ); - ensure!( - columns[0].data_type().is_numeric(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The first arg's type is not numeric, have: {}", - columns[0].data_type() - ), - } - ); - ensure!( - columns[0].data_type() == columns[1].data_type() - && columns[1].data_type() == columns[2].data_type(), - InvalidFuncArgsSnafu { - err_msg: format!( - "Arguments don't have identical types: {}, {}, {}", - columns[0].data_type(), - columns[1].data_type(), - columns[2].data_type() - ), - } - ); - - ensure_constant_vector(&columns[1])?; - ensure_constant_vector(&columns[2])?; - - with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { - let input_array = columns[0].to_arrow_array(); - let input = input_array - .as_any() - .downcast_ref::::ArrowPrimitive>>() - .unwrap(); - - let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0)) - .with_context(|| { - InvalidFuncArgsSnafu { - err_msg: "The second arg should not be none", - } - })?; - let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0)) - .with_context(|| { - InvalidFuncArgsSnafu { - err_msg: "The third arg should not be none", - } - })?; - - // ensure min <= max - ensure!( - min <= max, - InvalidFuncArgsSnafu { - err_msg: format!( - "The second arg should be less than or equal to the third arg, have: {:?}, {:?}", - columns[1], columns[2] - ), - } - ); - - clamp_impl::<$S, true, true>(input, min, max) - },{ - unreachable!() - }) + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [col, min, max] = utils::take_function_args(self.name(), args.args)?; + clamp_impl(col, min, max) } } @@ -157,25 +60,155 @@ impl Display for ClampFunction { } } -fn clamp_impl( - input: &PrimitiveArray, - min: T::Native, - max: T::Native, -) -> Result { - let iter = ArrayIter::new(input); - let result = iter.map(|x| { - x.map(|x| { - if CLAMP_MIN && x < min { - min - } else if CLAMP_MAX && x > max { - max - } else { - x +fn clamp_impl( + col: ColumnarValue, + min: ColumnarValue, + max: ColumnarValue, +) -> datafusion_common::Result { + if col.data_type() != min.data_type() || min.data_type() != max.data_type() { + return Err(DataFusionError::Execution(format!( + "argument data types mismatch: {}, {}, {}", + col.data_type(), + min.data_type(), + max.data_type(), + ))); + } + + macro_rules! with_match_numerics_types { + ($data_type:expr, | $_:tt $T:ident | $body:tt) => {{ + macro_rules! __with_ty__ { + ( $_ $T:ident ) => { + $body + }; } - }) - }); - let result = PrimitiveArray::::from_iter(result); - Ok(Arc::new(PrimitiveVector::::from(result))) + + use datafusion::arrow::datatypes::{ + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, + }; + + match $data_type { + ArrowDataType::Int8 => Ok(__with_ty__! { Int8Type }), + ArrowDataType::Int16 => Ok(__with_ty__! { Int16Type }), + ArrowDataType::Int32 => Ok(__with_ty__! { Int32Type }), + ArrowDataType::Int64 => Ok(__with_ty__! { Int64Type }), + ArrowDataType::UInt8 => Ok(__with_ty__! { UInt8Type }), + ArrowDataType::UInt16 => Ok(__with_ty__! { UInt16Type }), + ArrowDataType::UInt32 => Ok(__with_ty__! { UInt32Type }), + ArrowDataType::UInt64 => Ok(__with_ty__! { UInt64Type }), + ArrowDataType::Float32 => Ok(__with_ty__! { Float32Type }), + ArrowDataType::Float64 => Ok(__with_ty__! { Float64Type }), + _ => Err(DataFusionError::Execution(format!( + "unsupported numeric data type: '{}'", + $data_type + ))), + } + }}; + } + + macro_rules! clamp { + ($v: ident, $min: ident, $max: ident) => { + if $v < $min { + $min + } else if $v > $max { + $max + } else { + $v + } + }; + } + + match (col, min, max) { + (ColumnarValue::Scalar(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => { + if min > max { + return Err(DataFusionError::Execution(format!( + "min '{}' > max '{}'", + min, max + ))); + } + Ok(ColumnarValue::Scalar(clamp!(col, min, max))) + } + + (ColumnarValue::Array(col), ColumnarValue::Array(min), ColumnarValue::Array(max)) => { + if col.len() != min.len() || col.len() != max.len() { + return Err(DataFusionError::Internal( + "arguments not of same length".to_string(), + )); + } + let result = with_match_numerics_types!( + col.data_type(), + |$S| { + let col = col.as_primitive::<$S>(); + let min = min.as_primitive::<$S>(); + let max = max.as_primitive::<$S>(); + Arc::new(PrimitiveArray::<$S>::from( + (0..col.len()) + .map(|i| { + let v = col.is_valid(i).then(|| col.value(i)); + // Index safety: checked above, all have same length. + let min = min.is_valid(i).then(|| min.value(i)); + let max = max.is_valid(i).then(|| max.value(i)); + Ok(match (v, min, max) { + (Some(v), Some(min), Some(max)) => { + if min > max { + return Err(DataFusionError::Execution(format!( + "min '{}' > max '{}'", + min, max + ))); + } + Some(clamp!(v, min, max)) + }, + _ => None, + }) + }) + .collect::>>()?, + ) + ) as ArrayRef + } + )?; + Ok(ColumnarValue::Array(result)) + } + + (ColumnarValue::Array(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => { + if min.is_null() || max.is_null() { + return Err(DataFusionError::Execution( + "argument 'min' or 'max' is null".to_string(), + )); + } + let min = min.to_array()?; + let max = max.to_array()?; + let result = with_match_numerics_types!( + col.data_type(), + |$S| { + let col = col.as_primitive::<$S>(); + // Index safety: checked above, both are not nulls. + let min = min.as_primitive::<$S>().value(0); + let max = max.as_primitive::<$S>().value(0); + if min > max { + return Err(DataFusionError::Execution(format!( + "min '{}' > max '{}'", + min, max + ))); + } + Arc::new(PrimitiveArray::<$S>::from( + (0..col.len()) + .map(|x| { + col.is_valid(x).then(|| { + let v = col.value(x); + clamp!(v, min, max) + }) + }) + .collect::>(), + ) + ) as ArrayRef + } + )?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Internal( + "argument column types mismatch".to_string(), + )), + } } #[derive(Clone, Debug, Default)] @@ -197,59 +230,19 @@ impl Function for ClampMinFunction { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly 2, have: {}", - columns.len() - ), - } - ); - ensure!( - columns[0].data_type().is_numeric(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The first arg's type is not numeric, have: {}", - columns[0].data_type() - ), - } - ); - ensure!( - columns[0].data_type() == columns[1].data_type(), - InvalidFuncArgsSnafu { - err_msg: format!( - "Arguments don't have identical types: {}, {}", - columns[0].data_type(), - columns[1].data_type() - ), - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [col, min] = utils::take_function_args(self.name(), args.args)?; - ensure_constant_vector(&columns[1])?; - - with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { - let input_array = columns[0].to_arrow_array(); - let input = input_array - .as_any() - .downcast_ref::::ArrowPrimitive>>() - .unwrap(); - - let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0)) - .with_context(|| { - InvalidFuncArgsSnafu { - err_msg: "The second arg (min) should not be none", - } - })?; - // For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic. - // We pass a default/dummy value for max. - let max_dummy = <$S as LogicalPrimitiveType>::Native::default(); - - clamp_impl::<$S, true, false>(input, min, max_dummy) - },{ - unreachable!() - }) + let Some(max) = ScalarValue::max(&min.data_type()) else { + return Err(DataFusionError::Internal(format!( + "cannot find a max value for numeric data type {}", + min.data_type() + ))); + }; + clamp_impl(col, min, ColumnarValue::Scalar(max)) } } @@ -278,59 +271,19 @@ impl Function for ClampMaxFunction { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly 2, have: {}", - columns.len() - ), - } - ); - ensure!( - columns[0].data_type().is_numeric(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The first arg's type is not numeric, have: {}", - columns[0].data_type() - ), - } - ); - ensure!( - columns[0].data_type() == columns[1].data_type(), - InvalidFuncArgsSnafu { - err_msg: format!( - "Arguments don't have identical types: {}, {}", - columns[0].data_type(), - columns[1].data_type() - ), - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let [col, max] = utils::take_function_args(self.name(), args.args)?; - ensure_constant_vector(&columns[1])?; - - with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { - let input_array = columns[0].to_arrow_array(); - let input = input_array - .as_any() - .downcast_ref::::ArrowPrimitive>>() - .unwrap(); - - let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0)) - .with_context(|| { - InvalidFuncArgsSnafu { - err_msg: "The second arg (max) should not be none", - } - })?; - // For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic. - // We pass a default/dummy value for min. - let min_dummy = <$S as LogicalPrimitiveType>::Native::default(); - - clamp_impl::<$S, false, true>(input, min_dummy, max) - },{ - unreachable!() - }) + let Some(min) = ScalarValue::min(&max.data_type()) else { + return Err(DataFusionError::Internal(format!( + "cannot find a min value for numeric data type {}", + max.data_type() + ))); + }; + clamp_impl(col, ColumnarValue::Scalar(min), max) } } @@ -345,55 +298,80 @@ mod test { use std::sync::Arc; - use datatypes::prelude::ScalarVector; - use datatypes::vectors::{ - ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector, - }; + use arrow_schema::Field; + use datafusion_common::config::ConfigOptions; + use datatypes::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array}; + use datatypes::arrow_array::StringArray; use super::*; - use crate::function::FunctionContext; + + macro_rules! impl_test_eval { + ($func: ty) => { + impl $func { + fn test_eval( + &self, + args: Vec, + number_rows: usize, + ) -> datafusion_common::Result { + let input_type = args[0].data_type(); + self.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![], + number_rows, + return_field: Arc::new(Field::new("x", input_type, false)), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v]).map_err(Into::into)) + .map(|mut a| a.remove(0)) + } + } + }; + } + + impl_test_eval!(ClampFunction); + impl_test_eval!(ClampMinFunction); + impl_test_eval!(ClampMaxFunction); #[test] fn clamp_i64() { let inputs = [ ( vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], - -1, - 10, + -1i64, + 10i64, vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)], ), ( vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], - 0, - 0, + 0i64, + 0i64, vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)], ), ( vec![Some(-3), None, Some(-1), None, None, Some(2)], - -2, - 1, + -2i64, + 1i64, vec![Some(-2), None, Some(-1), None, None, Some(1)], ), ( vec![None, None, None, None, None], - 0, - 1, + 0i64, + 1i64, vec![None, None, None, None, None], ), ]; let func = ClampFunction; for (in_data, min, max, expected) in inputs { - let args = [ - Arc::new(Int64Vector::from(in_data)) as _, - Arc::new(Int64Vector::from_vec(vec![min])) as _, - Arc::new(Int64Vector::from_vec(vec![max])) as _, + let number_rows = in_data.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(in_data))), + ColumnarValue::Scalar(min.into()), + ColumnarValue::Scalar(max.into()), ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(Int64Vector::from(expected)); - assert_eq!(expected, result); + let result = func.test_eval(args, number_rows).unwrap(); + let expected: ArrayRef = Arc::new(Int64Array::from(expected)); + assert_eq!(expected.as_ref(), result.as_ref()); } } @@ -402,42 +380,41 @@ mod test { let inputs = [ ( vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)], - 1, - 3, + 1u64, + 3u64, vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)], ), ( vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)], - 0, - 0, + 0u64, + 0u64, vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)], ), ( vec![Some(0), None, Some(2), None, None, Some(5)], - 1, - 3, + 1u64, + 3u64, vec![Some(1), None, Some(2), None, None, Some(3)], ), ( vec![None, None, None, None, None], - 0, - 1, + 0u64, + 1u64, vec![None, None, None, None, None], ), ]; let func = ClampFunction; for (in_data, min, max, expected) in inputs { - let args = [ - Arc::new(UInt64Vector::from(in_data)) as _, - Arc::new(UInt64Vector::from_vec(vec![min])) as _, - Arc::new(UInt64Vector::from_vec(vec![max])) as _, + let number_rows = in_data.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(UInt64Array::from(in_data))), + ColumnarValue::Scalar(min.into()), + ColumnarValue::Scalar(max.into()), ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(UInt64Vector::from(expected)); - assert_eq!(expected, result); + let result = func.test_eval(args, number_rows).unwrap(); + let expected: ArrayRef = Arc::new(UInt64Array::from(expected)); + assert_eq!(expected.as_ref(), result.as_ref()); } } @@ -472,38 +449,18 @@ mod test { let func = ClampFunction; for (in_data, min, max, expected) in inputs { - let args = [ - Arc::new(Float64Vector::from(in_data)) as _, - Arc::new(Float64Vector::from_vec(vec![min])) as _, - Arc::new(Float64Vector::from_vec(vec![max])) as _, + let number_rows = in_data.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(in_data))), + ColumnarValue::Scalar(min.into()), + ColumnarValue::Scalar(max.into()), ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(Float64Vector::from(expected)); - assert_eq!(expected, result); + let result = func.test_eval(args, number_rows).unwrap(); + let expected: ArrayRef = Arc::new(Float64Array::from(expected)); + assert_eq!(expected.as_ref(), result.as_ref()); } } - #[test] - fn clamp_const_i32() { - let input = vec![Some(5)]; - let min = 2; - let max = 4; - - let func = ClampFunction; - let args = [ - Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _, - Arc::new(Int64Vector::from_vec(vec![min])) as _, - Arc::new(Int64Vector::from_vec(vec![max])) as _, - ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)])); - assert_eq!(expected, result); - } - #[test] fn clamp_invalid_min_max() { let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; @@ -511,28 +468,30 @@ mod test { let max = -1.0; let func = ClampFunction; - let args = [ - Arc::new(Float64Vector::from(input)) as _, - Arc::new(Float64Vector::from_vec(vec![min])) as _, - Arc::new(Float64Vector::from_vec(vec![max])) as _, + let number_rows = input.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(input))), + ColumnarValue::Scalar(min.into()), + ColumnarValue::Scalar(max.into()), ]; - let result = func.eval(&FunctionContext::default(), args.as_slice()); + let result = func.test_eval(args, number_rows); assert!(result.is_err()); } #[test] fn clamp_type_not_match() { let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; - let min = -1; - let max = 10; + let min = -1i64; + let max = 10u64; let func = ClampFunction; - let args = [ - Arc::new(Float64Vector::from(input)) as _, - Arc::new(Int64Vector::from_vec(vec![min])) as _, - Arc::new(UInt64Vector::from_vec(vec![max])) as _, + let number_rows = input.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(input))), + ColumnarValue::Scalar(min.into()), + ColumnarValue::Scalar(max.into()), ]; - let result = func.eval(&FunctionContext::default(), args.as_slice()); + let result = func.test_eval(args, number_rows); assert!(result.is_err()); } @@ -543,12 +502,13 @@ mod test { let max = 1.0; let func = ClampFunction; - let args = [ - Arc::new(Float64Vector::from(input)) as _, - Arc::new(Float64Vector::from_vec(vec![min, max])) as _, - Arc::new(Float64Vector::from_vec(vec![max, min])) as _, + let number_rows = input.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(input))), + ColumnarValue::Array(Arc::new(Float64Array::from(vec![min, max]))), + ColumnarValue::Array(Arc::new(Float64Array::from(vec![max, min]))), ]; - let result = func.eval(&FunctionContext::default(), args.as_slice()); + let result = func.test_eval(args, number_rows); assert!(result.is_err()); } @@ -558,11 +518,12 @@ mod test { let min = -10.0; let func = ClampFunction; - let args = [ - Arc::new(Float64Vector::from(input)) as _, - Arc::new(Float64Vector::from_vec(vec![min])) as _, + let number_rows = input.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(input))), + ColumnarValue::Scalar(min.into()), ]; - let result = func.eval(&FunctionContext::default(), args.as_slice()); + let result = func.test_eval(args, number_rows); assert!(result.is_err()); } @@ -571,12 +532,13 @@ mod test { let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")]; let func = ClampFunction; - let args = [ - Arc::new(StringVector::from(input)) as _, - Arc::new(StringVector::from_vec(vec!["bar"])) as _, - Arc::new(StringVector::from_vec(vec!["baz"])) as _, + let number_rows = input.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(StringArray::from(input))), + ColumnarValue::Scalar("bar".into()), + ColumnarValue::Scalar("baz".into()), ]; - let result = func.eval(&FunctionContext::default(), args.as_slice()); + let result = func.test_eval(args, number_rows); assert!(result.is_err()); } @@ -585,27 +547,26 @@ mod test { let inputs = [ ( vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], - -1, + -1i64, vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)], ), ( vec![Some(-3), None, Some(-1), None, None, Some(2)], - -2, + -2i64, vec![Some(-2), None, Some(-1), None, None, Some(2)], ), ]; let func = ClampMinFunction; for (in_data, min, expected) in inputs { - let args = [ - Arc::new(Int64Vector::from(in_data)) as _, - Arc::new(Int64Vector::from_vec(vec![min])) as _, + let number_rows = in_data.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(in_data))), + ColumnarValue::Scalar(min.into()), ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(Int64Vector::from(expected)); - assert_eq!(expected, result); + let result = func.test_eval(args, number_rows).unwrap(); + let expected: ArrayRef = Arc::new(Int64Array::from(expected)); + assert_eq!(expected.as_ref(), result.as_ref()); } } @@ -614,27 +575,26 @@ mod test { let inputs = [ ( vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], - 1, + 1i64, vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)], ), ( vec![Some(-3), None, Some(-1), None, None, Some(2)], - 0, + 0i64, vec![Some(-3), None, Some(-1), None, None, Some(0)], ), ]; let func = ClampMaxFunction; for (in_data, max, expected) in inputs { - let args = [ - Arc::new(Int64Vector::from(in_data)) as _, - Arc::new(Int64Vector::from_vec(vec![max])) as _, + let number_rows = in_data.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Int64Array::from(in_data))), + ColumnarValue::Scalar(max.into()), ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(Int64Vector::from(expected)); - assert_eq!(expected, result); + let result = func.test_eval(args, number_rows).unwrap(); + let expected: ArrayRef = Arc::new(Int64Array::from(expected)); + assert_eq!(expected.as_ref(), result.as_ref()); } } @@ -648,15 +608,14 @@ mod test { let func = ClampMinFunction; for (in_data, min, expected) in inputs { - let args = [ - Arc::new(Float64Vector::from(in_data)) as _, - Arc::new(Float64Vector::from_vec(vec![min])) as _, + let number_rows = in_data.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(in_data))), + ColumnarValue::Scalar(min.into()), ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(Float64Vector::from(expected)); - assert_eq!(expected, result); + let result = func.test_eval(args, number_rows).unwrap(); + let expected: ArrayRef = Arc::new(Float64Array::from(expected)); + assert_eq!(expected.as_ref(), result.as_ref()); } } @@ -670,43 +629,44 @@ mod test { let func = ClampMaxFunction; for (in_data, max, expected) in inputs { - let args = [ - Arc::new(Float64Vector::from(in_data)) as _, - Arc::new(Float64Vector::from_vec(vec![max])) as _, + let number_rows = in_data.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(in_data))), + ColumnarValue::Scalar(max.into()), ]; - let result = func - .eval(&FunctionContext::default(), args.as_slice()) - .unwrap(); - let expected: VectorRef = Arc::new(Float64Vector::from(expected)); - assert_eq!(expected, result); + let result = func.test_eval(args, number_rows).unwrap(); + let expected: ArrayRef = Arc::new(Float64Array::from(expected)); + assert_eq!(expected.as_ref(), result.as_ref()); } } #[test] fn clamp_min_type_not_match() { let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; - let min = -1; + let min = -1i64; let func = ClampMinFunction; - let args = [ - Arc::new(Float64Vector::from(input)) as _, - Arc::new(Int64Vector::from_vec(vec![min])) as _, + let number_rows = input.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(input))), + ColumnarValue::Scalar(min.into()), ]; - let result = func.eval(&FunctionContext::default(), args.as_slice()); + let result = func.test_eval(args, number_rows); assert!(result.is_err()); } #[test] fn clamp_max_type_not_match() { let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; - let max = 1; + let max = 1i64; let func = ClampMaxFunction; - let args = [ - Arc::new(Float64Vector::from(input)) as _, - Arc::new(Int64Vector::from_vec(vec![max])) as _, + let number_rows = input.len(); + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(input))), + ColumnarValue::Scalar(max.into()), ]; - let result = func.eval(&FunctionContext::default(), args.as_slice()); + let result = func.test_eval(args, number_rows); assert!(result.is_err()); } } diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index 46abf1e163..ba98823c54 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -28,7 +28,14 @@ mod vector_norm; mod vector_sub; mod vector_subvector; +use std::borrow::Cow; + +use datafusion_common::{DataFusionError, Result, ScalarValue, utils}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; + use crate::function_registry::FunctionRegistry; +use crate::scalars::vector::impl_conv::as_veclit; + pub(crate) struct VectorFunction; impl VectorFunction { @@ -59,3 +66,155 @@ impl VectorFunction { registry.register_scalar(elem_product::ElemProductFunction); } } + +// Use macro instead of function to "return" the reference to `ScalarValue` in the +// `ColumnarValue::Array` match arm. +macro_rules! try_get_scalar_value { + ($col: ident, $i: ident) => { + match $col { + datafusion::logical_expr::ColumnarValue::Array(a) => { + &datafusion_common::ScalarValue::try_from_array(a.as_ref(), $i)? + } + datafusion::logical_expr::ColumnarValue::Scalar(v) => v, + } + }; +} + +pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result { + if values.is_empty() { + return Ok(0); + } + + let mut array_len = None; + for v in values { + array_len = match (v, array_len) { + (ColumnarValue::Array(a), None) => Some(a.len()), + (ColumnarValue::Array(a), Some(array_len)) => { + if array_len == a.len() { + Some(array_len) + } else { + return Err(DataFusionError::Internal(format!( + "Arguments has mixed length. Expected length: {array_len}, found length: {}", + a.len() + ))); + } + } + (ColumnarValue::Scalar(_), array_len) => array_len, + } + } + + // If array_len is none, it means there are only scalars, treat them each as 1 element array. + let array_len = array_len.unwrap_or(1); + Ok(array_len) +} + +struct VectorCalculator<'a, F> { + name: &'a str, + func: F, +} + +impl VectorCalculator<'_, F> +where + F: Fn(&ScalarValue, &ScalarValue) -> Result, +{ + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?; + + if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) { + let result = (self.func)(v0, v1)?; + return Ok(ColumnarValue::Scalar(result)); + } + + let len = ensure_same_length(&[arg0, arg1])?; + let mut results = Vec::with_capacity(len); + for i in 0..len { + let v0 = try_get_scalar_value!(arg0, i); + let v1 = try_get_scalar_value!(arg1, i); + results.push((self.func)(v0, v1)?); + } + + let results = ScalarValue::iter_to_array(results.into_iter())?; + Ok(ColumnarValue::Array(results)) + } +} + +impl VectorCalculator<'_, F> +where + F: Fn(&Option>, &Option>) -> Result, +{ + fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result { + let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?; + + if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) { + let v0 = as_veclit(v0)?; + let v1 = as_veclit(v1)?; + let result = (self.func)(&v0, &v1)?; + return Ok(ColumnarValue::Scalar(result)); + } + + let len = ensure_same_length(&[arg0, arg1])?; + let mut results = Vec::with_capacity(len); + + match (arg0, arg1) { + (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => { + let v0 = as_veclit(v0)?; + for i in 0..len { + let v1 = ScalarValue::try_from_array(a1, i)?; + let v1 = as_veclit(&v1)?; + results.push((self.func)(&v0, &v1)?); + } + } + (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => { + let v1 = as_veclit(v1)?; + for i in 0..len { + let v0 = ScalarValue::try_from_array(a0, i)?; + let v0 = as_veclit(&v0)?; + results.push((self.func)(&v0, &v1)?); + } + } + (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => { + for i in 0..len { + let v0 = ScalarValue::try_from_array(a0, i)?; + let v0 = as_veclit(&v0)?; + let v1 = ScalarValue::try_from_array(a1, i)?; + let v1 = as_veclit(&v1)?; + results.push((self.func)(&v0, &v1)?); + } + } + (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => { + // unreachable because this arm has been separately dealt with above + unreachable!() + } + } + + let results = ScalarValue::iter_to_array(results.into_iter())?; + Ok(ColumnarValue::Array(results)) + } +} + +impl VectorCalculator<'_, F> +where + F: Fn(&ScalarValue) -> Result, +{ + fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result { + let [arg0] = utils::take_function_args(self.name, &args.args)?; + + let arg0 = match arg0 { + ColumnarValue::Scalar(v) => { + let result = (self.func)(v)?; + return Ok(ColumnarValue::Scalar(result)); + } + ColumnarValue::Array(a) => a, + }; + + let len = arg0.len(); + let mut results = Vec::with_capacity(len); + for i in 0..len { + let v = ScalarValue::try_from_array(arg0, i)?; + results.push((self.func)(&v)?); + } + + let results = ScalarValue::iter_to_array(results.into_iter())?; + Ok(ColumnarValue::Array(results)) + } +} diff --git a/src/common/function/src/scalars/vector/convert/vector_to_string.rs b/src/common/function/src/scalars/vector/convert/vector_to_string.rs index 58fb22b61a..3dc5f06ad6 100644 --- a/src/common/function/src/scalars/vector/convert/vector_to_string.rs +++ b/src/common/function/src/scalars/vector/convert/vector_to_string.rs @@ -17,7 +17,7 @@ use std::fmt::Display; use common_query::error::{InvalidFuncArgsSnafu, Result}; use datafusion::arrow::datatypes::DataType; use datafusion_expr::type_coercion::aggregates::BINARYS; -use datafusion_expr::{Signature, Volatility}; +use datafusion_expr::{Signature, TypeSignature, Volatility}; use datatypes::scalars::ScalarVectorBuilder; use datatypes::types::vector_type_value_to_string; use datatypes::value::Value; @@ -41,7 +41,13 @@ impl Function for VectorToStringFunction { } fn signature(&self) -> Signature { - Signature::uniform(1, BINARYS.to_vec(), Volatility::Immutable) + Signature::one_of( + vec![ + TypeSignature::Uniform(1, vec![DataType::BinaryView]), + TypeSignature::Uniform(1, BINARYS.to_vec()), + ], + Volatility::Immutable, + ) } fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs index 864e2405b0..ab79f64143 100644 --- a/src/common/function/src/scalars/vector/distance.rs +++ b/src/common/function/src/scalars/vector/distance.rs @@ -19,20 +19,17 @@ mod l2sq; use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; +use common_query::error::Result; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::ScalarValue; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef}; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const}; macro_rules! define_distance_function { ($StructName:ident, $display_name:expr, $similarity_method:path) => { - /// A function calculates the distance between two vectors. #[derive(Debug, Clone, Default)] @@ -54,59 +51,34 @@ macro_rules! define_distance_function { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ), - } - ); - let arg0 = &columns[0]; - let arg1 = &columns[1]; + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &Option>, + v1: &Option>| + -> datafusion_common::Result { + let result = if let (Some(v0), Some(v1)) = (v0, v1) { + if v0.len() != v1.len() { + return Err(datafusion_common::DataFusionError::Execution(format!( + "vectors length not match: {}", + self.name() + ))); + } - let size = arg0.len(); - let mut result = Float32VectorBuilder::with_capacity(size); - if size == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - let arg1_const = as_veclit_if_const(arg1)?; - - for i in 0..size { - let vec0 = match arg0_const.as_ref() { - Some(a) => Some(Cow::Borrowed(a.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - let vec1 = match arg1_const.as_ref() { - Some(b) => Some(Cow::Borrowed(b.as_ref())), - None => as_veclit(arg1.get_ref(i))?, - }; - - if let (Some(vec0), Some(vec1)) = (vec0, vec1) { - ensure!( - vec0.len() == vec1.len(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the vectors must match to calculate distance, have: {} vs {}", - vec0.len(), - vec1.len() - ), - } - ); - - // Checked if the length of the vectors match - let d = $similarity_method(vec0.as_ref(), vec1.as_ref()); - result.push(Some(d)); + let d = $similarity_method(v0, v1); + Some(d) } else { - result.push_null(); - } - } + None + }; + Ok(ScalarValue::Float32(result)) + }; - return Ok(result.to_vector()); + let calculator = $crate::scalars::vector::VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_vectors(args) } } @@ -115,7 +87,7 @@ macro_rules! define_distance_function { write!(f, "{}", $display_name.to_ascii_uppercase()) } } - } + }; } define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos); @@ -126,10 +98,29 @@ define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot); mod tests { use std::sync::Arc; - use datatypes::vectors::{BinaryVector, ConstantVector, StringVector}; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray}; + use datafusion::arrow::datatypes::Float32Type; + use datafusion_common::config::ConfigOptions; use super::*; + fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result { + let number_rows = args[0].len(); + let args = ScalarFunctionArgs { + args: args + .iter() + .map(|x| ColumnarValue::Array(x.clone())) + .collect::>(), + arg_fields: vec![], + number_rows, + return_field: Arc::new(Field::new("x", DataType::Float32, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + func.invoke_with_args(args) + .and_then(|x| x.to_array(number_rows)) + } + #[test] fn test_distance_string_string() { let funcs = [ @@ -139,36 +130,34 @@ mod tests { ]; for func in funcs { - let vec1 = Arc::new(StringVector::from(vec![ + let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[0.0, 1.0]"), Some("[1.0, 0.0]"), None, Some("[1.0, 0.0]"), - ])) as VectorRef; - let vec2 = Arc::new(StringVector::from(vec![ + ])); + let vec2: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[0.0, 1.0]"), Some("[0.0, 1.0]"), Some("[0.0, 1.0]"), None, - ])) as VectorRef; + ])); - let result = func - .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()]) - .unwrap(); + let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap(); + let result = result.as_primitive::(); - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(result.get(3).is_null()); + assert!(!result.is_null(0)); + assert!(!result.is_null(1)); + assert!(result.is_null(2)); + assert!(result.is_null(3)); - let result = func - .eval(&FunctionContext::default(), &[vec2, vec1]) - .unwrap(); + let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap(); + let result = result.as_primitive::(); - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(result.get(3).is_null()); + assert!(!result.is_null(0)); + assert!(!result.is_null(1)); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } } @@ -181,37 +170,35 @@ mod tests { ]; for func in funcs { - let vec1 = Arc::new(BinaryVector::from(vec![ + let vec1: ArrayRef = Arc::new(BinaryArray::from_iter(vec![ Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), Some(vec![0, 0, 128, 63, 0, 0, 0, 0]), None, Some(vec![0, 0, 128, 63, 0, 0, 0, 0]), - ])) as VectorRef; - let vec2 = Arc::new(BinaryVector::from(vec![ + ])); + let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![ // [0.0, 1.0] Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), None, - ])) as VectorRef; + ])); - let result = func - .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()]) - .unwrap(); + let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap(); + let result = result.as_primitive::(); - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(result.get(3).is_null()); + assert!(!result.is_null(0)); + assert!(!result.is_null(1)); + assert!(result.is_null(2)); + assert!(result.is_null(3)); - let result = func - .eval(&FunctionContext::default(), &[vec2, vec1]) - .unwrap(); + let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap(); + let result = result.as_primitive::(); - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(result.get(3).is_null()); + assert!(!result.is_null(0)); + assert!(!result.is_null(1)); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } } @@ -224,115 +211,35 @@ mod tests { ]; for func in funcs { - let vec1 = Arc::new(StringVector::from(vec![ + let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[0.0, 1.0]"), Some("[1.0, 0.0]"), None, Some("[1.0, 0.0]"), - ])) as VectorRef; - let vec2 = Arc::new(BinaryVector::from(vec![ + ])); + let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![ // [0.0, 1.0] Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), None, - ])) as VectorRef; + ])); - let result = func - .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()]) - .unwrap(); + let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap(); + let result = result.as_primitive::(); - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(result.get(3).is_null()); + assert!(!result.is_null(0)); + assert!(!result.is_null(1)); + assert!(result.is_null(2)); + assert!(result.is_null(3)); - let result = func - .eval(&FunctionContext::default(), &[vec2, vec1]) - .unwrap(); + let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap(); + let result = result.as_primitive::(); - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(result.get(3).is_null()); - } - } - - #[test] - fn test_distance_const_string() { - let funcs = [ - Box::new(CosDistanceFunction {}) as Box, - Box::new(L2SqDistanceFunction {}) as Box, - Box::new(DotProductFunction {}) as Box, - ]; - - for func in funcs { - let const_str = Arc::new(ConstantVector::new( - Arc::new(StringVector::from(vec!["[0.0, 1.0]"])), - 4, - )); - - let vec1 = Arc::new(StringVector::from(vec![ - Some("[0.0, 1.0]"), - Some("[1.0, 0.0]"), - None, - Some("[1.0, 0.0]"), - ])) as VectorRef; - let vec2 = Arc::new(BinaryVector::from(vec![ - // [0.0, 1.0] - Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), - Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), - Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), - None, - ])) as VectorRef; - - let result = func - .eval( - &FunctionContext::default(), - &[const_str.clone(), vec1.clone()], - ) - .unwrap(); - - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(!result.get(3).is_null()); - - let result = func - .eval( - &FunctionContext::default(), - &[vec1.clone(), const_str.clone()], - ) - .unwrap(); - - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(result.get(2).is_null()); - assert!(!result.get(3).is_null()); - - let result = func - .eval( - &FunctionContext::default(), - &[const_str.clone(), vec2.clone()], - ) - .unwrap(); - - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(!result.get(2).is_null()); - assert!(result.get(3).is_null()); - - let result = func - .eval( - &FunctionContext::default(), - &[vec2.clone(), const_str.clone()], - ) - .unwrap(); - - assert!(!result.get(0).is_null()); - assert!(!result.get(1).is_null()); - assert!(!result.get(2).is_null()); - assert!(result.get(3).is_null()); + assert!(!result.is_null(0)); + assert!(!result.is_null(1)); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } } @@ -345,15 +252,16 @@ mod tests { ]; for func in funcs { - let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef; - let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef; - let result = func.eval(&FunctionContext::default(), &[vec1, vec2]); + let vec1: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"])); + let vec2: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"])); + let result = test_invoke(func.as_ref(), &[vec1, vec2]); assert!(result.is_err()); - let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef; - let vec2 = - Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef; - let result = func.eval(&FunctionContext::default(), &[vec1, vec2]); + let vec1: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]])); + let vec2: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![ + 0, 0, 128, 63, 0, 0, 0, 64, + ]])); + let result = test_invoke(func.as_ref(), &[vec1, vec2]); assert!(result.is_err()); } } diff --git a/src/common/function/src/scalars/vector/elem_product.rs b/src/common/function/src/scalars/vector/elem_product.rs index 954f0f73a6..c527faa616 100644 --- a/src/common/function/src/scalars/vector/elem_product.rs +++ b/src/common/function/src/scalars/vector/elem_product.rs @@ -12,20 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::error::Result; use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef}; +use datafusion_common::ScalarValue; +use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const}; +use crate::function::Function; +use crate::scalars::vector::{VectorCalculator, impl_conv}; const NAME: &str = "vec_elem_product"; @@ -64,43 +62,21 @@ impl Function for ElemProductFunction { ) } - fn eval( + fn invoke_with_args( &self, - _func_ctx: &FunctionContext, - columns: &[VectorRef], - ) -> common_query::error::Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly one, have: {}", - columns.len() - ) - } - ); - let arg0 = &columns[0]; + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue| -> datafusion_common::Result { + let v0 = impl_conv::as_veclit(v0)? + .map(|v0| DVectorView::from_slice(&v0, v0.len()).product()); + Ok(ScalarValue::Float32(v0)) + }; - let len = arg0.len(); - let mut result = Float32VectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - let Some(arg0) = arg0 else { - result.push_null(); - continue; - }; - result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product())); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_single_argument(args) } } @@ -114,27 +90,39 @@ impl Display for ElemProductFunction { mod tests { use std::sync::Arc; - use datatypes::vectors::StringVector; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, StringArray}; + use datafusion::arrow::datatypes::Float32Type; + use datafusion_common::config::ConfigOptions; use super::*; - use crate::function::FunctionContext; #[test] fn test_elem_product() { let func = ElemProductFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input = Arc::new(StringArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), None, ])); - let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); + let result = func + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input.clone())], + arg_fields: vec![], + number_rows: input.len(), + return_field: Arc::new(Field::new("x", DataType::Float32, true)), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v])) + .map(|mut a| a.remove(0)) + .unwrap(); + let result = result.as_primitive::(); - let result = result.as_ref(); assert_eq!(result.len(), 3); - assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0)); - assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(120.0)); - assert_eq!(result.get_ref(2).as_f32().unwrap(), None); + assert_eq!(result.value(0), 6.0); + assert_eq!(result.value(1), 120.0); + assert!(result.is_null(2)); } } diff --git a/src/common/function/src/scalars/vector/elem_sum.rs b/src/common/function/src/scalars/vector/elem_sum.rs index 8bea6d6c32..08e078df73 100644 --- a/src/common/function/src/scalars/vector/elem_sum.rs +++ b/src/common/function/src/scalars/vector/elem_sum.rs @@ -12,20 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::error::Result; use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef}; +use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const}; +use crate::function::Function; +use crate::scalars::vector::{VectorCalculator, impl_conv}; const NAME: &str = "vec_elem_sum"; @@ -51,43 +49,21 @@ impl Function for ElemSumFunction { ) } - fn eval( + fn invoke_with_args( &self, - _func_ctx: &FunctionContext, - columns: &[VectorRef], - ) -> common_query::error::Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly one, have: {}", - columns.len() - ) - } - ); - let arg0 = &columns[0]; + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue| -> datafusion_common::Result { + let v0 = + impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum()); + Ok(ScalarValue::Float32(v0)) + }; - let len = arg0.len(); - let mut result = Float32VectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - let Some(arg0) = arg0 else { - result.push_null(); - continue; - }; - result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum())); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_single_argument(args) } } @@ -101,27 +77,40 @@ impl Display for ElemSumFunction { mod tests { use std::sync::Arc; - use datatypes::vectors::StringVector; + use arrow::array::StringViewArray; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray}; + use datafusion::arrow::datatypes::Float32Type; + use datafusion_common::config::ConfigOptions; use super::*; - use crate::function::FunctionContext; #[test] fn test_elem_sum() { let func = ElemSumFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), None, ])); - let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); + let result = func + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input.clone())], + arg_fields: vec![], + number_rows: input.len(), + return_field: Arc::new(Field::new("x", DataType::Float32, true)), + config_options: Arc::new(ConfigOptions::new()), + }) + .and_then(|v| ColumnarValue::values_to_arrays(&[v])) + .map(|mut a| a.remove(0)) + .unwrap(); + let result = result.as_primitive::(); - let result = result.as_ref(); assert_eq!(result.len(), 3); - assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0)); - assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(15.0)); - assert_eq!(result.get_ref(2).as_f32().unwrap(), None); + assert_eq!(result.value(0), 6.0); + assert_eq!(result.value(1), 15.0); + assert!(result.is_null(2)); } } diff --git a/src/common/function/src/scalars/vector/impl_conv.rs b/src/common/function/src/scalars/vector/impl_conv.rs index 70a142c290..e97ca85e33 100644 --- a/src/common/function/src/scalars/vector/impl_conv.rs +++ b/src/common/function/src/scalars/vector/impl_conv.rs @@ -13,40 +13,18 @@ // limitations under the License. use std::borrow::Cow; -use std::sync::Arc; use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datatypes::prelude::ConcreteDataType; -use datatypes::value::ValueRef; -use datatypes::vectors::Vector; - -/// Convert a constant string or binary literal to a vector literal. -pub fn as_veclit_if_const(arg: &Arc) -> Result>> { - if !arg.is_const() { - return Ok(None); - } - if arg.data_type() != ConcreteDataType::string_datatype() - && arg.data_type() != ConcreteDataType::binary_datatype() - { - return Ok(None); - } - as_veclit(arg.get_ref(0)) -} +use datafusion_common::ScalarValue; /// Convert a string or binary literal to a vector literal. -pub fn as_veclit(arg: ValueRef<'_>) -> Result>> { - match arg.data_type() { - ConcreteDataType::Binary(_) => arg - .as_binary() - .unwrap() // Safe: checked if it is a binary - .map(binlit_as_veclit) +pub fn as_veclit(arg: &ScalarValue) -> Result>> { + match arg { + ScalarValue::Binary(b) => b.as_ref().map(|x| binlit_as_veclit(x)).transpose(), + ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) => s + .as_ref() + .map(|x| parse_veclit_from_strlit(x).map(Cow::Owned)) .transpose(), - ConcreteDataType::String(_) => arg - .as_string() - .unwrap() // Safe: checked if it is a string - .map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?))) - .transpose(), - ConcreteDataType::Null(_) => Ok(None), _ => InvalidFuncArgsSnafu { err_msg: format!("Unsupported data type: {:?}", arg.data_type()), } diff --git a/src/common/function/src/scalars/vector/scalar_add.rs b/src/common/function/src/scalars/vector/scalar_add.rs index 81a532132e..187eccc761 100644 --- a/src/common/function/src/scalars/vector/scalar_add.rs +++ b/src/common/function/src/scalars/vector/scalar_add.rs @@ -12,20 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use common_query::error::Result; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::ScalarValue; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; const NAME: &str = "vec_scalar_add"; @@ -60,7 +59,7 @@ impl Function for ScalarAddFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -70,52 +69,26 @@ impl Function for ScalarAddFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ), - } - ); - let arg0 = &columns[0]; - let arg1 = &columns[1]; - - let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg1_const = as_veclit_if_const(arg1)?; - - for i in 0..len { - let arg0 = arg0.get(i).as_f64_lossy(); - let Some(arg0) = arg0 else { - result.push_null(); - continue; + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result { + let ScalarValue::Float64(Some(v0)) = v0 else { + return Ok(ScalarValue::BinaryView(None)); }; - let arg1 = match arg1_const.as_ref() { - Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())), - None => as_veclit(arg1.get_ref(i))?, - }; - let Some(arg1) = arg1 else { - result.push_null(); - continue; - }; + let v1 = as_veclit(v1)? + .map(|v1| DVectorView::from_slice(&v1, v1.len()).add_scalar(*v0 as f32)); + let result = v1.map(|v1| veclit_to_binlit(v1.as_slice())); + Ok(ScalarValue::BinaryView(result)) + }; - let vec = DVectorView::from_slice(&arg1, arg1.len()); - let vec_res = vec.add_scalar(arg0 as _); - - let veclit = vec_res.as_slice(); - let binlit = veclit_to_binlit(veclit); - result.push(Some(&binlit)); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_args(args) } } @@ -129,7 +102,9 @@ impl Display for ScalarAddFunction { mod tests { use std::sync::Arc; - use datatypes::vectors::{Float32Vector, StringVector}; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; @@ -137,34 +112,42 @@ mod tests { fn test_scalar_add() { let func = ScalarAddFunction; - let input0 = Arc::new(Float32Vector::from(vec![ + let input0 = Arc::new(Float64Array::from(vec![ Some(1.0), Some(-1.0), None, Some(3.0), ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1 = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), Some("[7.0,8.0,9.0]".to_string()), None, ])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 4); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()) + result.value(0), + veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice() ); assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice()) + result.value(1), + veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice() ); - assert!(result.get_ref(2).is_null()); - assert!(result.get_ref(3).is_null()); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } } diff --git a/src/common/function/src/scalars/vector/scalar_mul.rs b/src/common/function/src/scalars/vector/scalar_mul.rs index 8985987331..27127e8ee9 100644 --- a/src/common/function/src/scalars/vector/scalar_mul.rs +++ b/src/common/function/src/scalars/vector/scalar_mul.rs @@ -12,20 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use common_query::error::Result; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::ScalarValue; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; const NAME: &str = "vec_scalar_mul"; @@ -60,7 +59,7 @@ impl Function for ScalarMulFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -70,52 +69,26 @@ impl Function for ScalarMulFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ), - } - ); - let arg0 = &columns[0]; - let arg1 = &columns[1]; - - let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg1_const = as_veclit_if_const(arg1)?; - - for i in 0..len { - let arg0 = arg0.get(i).as_f64_lossy(); - let Some(arg0) = arg0 else { - result.push_null(); - continue; + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result { + let ScalarValue::Float64(Some(v0)) = v0 else { + return Ok(ScalarValue::BinaryView(None)); }; - let arg1 = match arg1_const.as_ref() { - Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())), - None => as_veclit(arg1.get_ref(i))?, - }; - let Some(arg1) = arg1 else { - result.push_null(); - continue; - }; + let v1 = + as_veclit(v1)?.map(|v1| DVectorView::from_slice(&v1, v1.len()).scale(*v0 as f32)); + let result = v1.map(|v1| veclit_to_binlit(v1.as_slice())); + Ok(ScalarValue::BinaryView(result)) + }; - let vec = DVectorView::from_slice(&arg1, arg1.len()); - let vec_res = vec.scale(arg0 as _); - - let veclit = vec_res.as_slice(); - let binlit = veclit_to_binlit(veclit); - result.push(Some(&binlit)); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_args(args) } } @@ -129,7 +102,9 @@ impl Display for ScalarMulFunction { mod tests { use std::sync::Arc; - use datatypes::vectors::{Float32Vector, StringVector}; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; @@ -137,34 +112,42 @@ mod tests { fn test_scalar_mul() { let func = ScalarMulFunction; - let input0 = Arc::new(Float32Vector::from(vec![ + let input0 = Arc::new(Float64Array::from(vec![ Some(2.0), Some(-0.5), None, Some(3.0), ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1 = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[8.0,10.0,12.0]".to_string()), Some("[7.0,8.0,9.0]".to_string()), None, ])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 4); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice()) + result.value(0), + veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice() ); assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice()) + result.value(1), + veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice() ); - assert!(result.get_ref(2).is_null()); - assert!(result.get_ref(3).is_null()); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } } diff --git a/src/common/function/src/scalars/vector/vector_add.rs b/src/common/function/src/scalars/vector/vector_add.rs index 3ecab37938..e07cefbf09 100644 --- a/src/common/function/src/scalars/vector/vector_add.rs +++ b/src/common/function/src/scalars/vector/vector_add.rs @@ -15,17 +15,17 @@ use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use common_query::error::Result; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::veclit_to_binlit; const NAME: &str = "vec_add"; @@ -51,7 +51,7 @@ impl Function for VectorAddFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -61,66 +61,36 @@ impl Function for VectorAddFunction { ) } - fn eval( + fn invoke_with_args( &self, - _func_ctx: &FunctionContext, - columns: &[VectorRef], - ) -> common_query::error::Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ) - } - ); - let arg0 = &columns[0]; - let arg1 = &columns[1]; + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &Option>, + v1: &Option>| + -> datafusion_common::Result { + let result = if let (Some(v0), Some(v1)) = (v0, v1) { + let v0 = DVectorView::from_slice(v0, v0.len()); + let v1 = DVectorView::from_slice(v1, v1.len()); + if v0.len() != v1.len() { + return Err(DataFusionError::Execution(format!( + "vectors length not match: {}", + self.name() + ))); + } - ensure!( - arg0.len() == arg1.len(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The lengths of the vector are not aligned, args 0: {}, args 1: {}", - arg0.len(), - arg1.len(), - ) - } - ); - - let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - let arg1_const = as_veclit_if_const(arg1)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, + let result = veclit_to_binlit((v0 + v1).as_slice()); + Some(result) + } else { + None }; - let arg1 = match arg1_const.as_ref() { - Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())), - None => as_veclit(arg1.get_ref(i))?, - }; - let (Some(arg0), Some(arg1)) = (arg0, arg1) else { - result.push_null(); - continue; - }; - let vec0 = DVectorView::from_slice(&arg0, arg0.len()); - let vec1 = DVectorView::from_slice(&arg1, arg1.len()); + Ok(ScalarValue::BinaryView(result)) + }; - let vec_res = vec0 + vec1; - let veclit = vec_res.as_slice(); - let binlit = veclit_to_binlit(veclit); - result.push(Some(&binlit)); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_vectors(args) } } @@ -134,8 +104,9 @@ impl Display for VectorAddFunction { mod tests { use std::sync::Arc; - use common_query::error::Error; - use datatypes::vectors::StringVector; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; @@ -143,63 +114,71 @@ mod tests { fn test_sub() { let func = VectorAddFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), None, Some("[2.0,3.0,3.0]".to_string()), ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1 = Arc::new(StringViewArray::from(vec![ Some("[1.0,1.0,1.0]".to_string()), Some("[6.0,5.0,4.0]".to_string()), Some("[3.0,2.0,2.0]".to_string()), None, ])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 4); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()) + result.value(0), + veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice() ); assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()) + result.value(1), + veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice() ); - assert!(result.get_ref(2).is_null()); - assert!(result.get_ref(3).is_null()); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } #[test] fn test_sub_error() { let func = VectorAddFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), None, Some("[2.0,3.0,3.0]".to_string()), ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1 = Arc::new(StringViewArray::from(vec![ Some("[1.0,1.0,1.0]".to_string()), Some("[6.0,5.0,4.0]".to_string()), Some("[3.0,2.0,2.0]".to_string()), ])); - let result = func.eval(&FunctionContext::default(), &[input0, input1]); - - match result { - Err(Error::InvalidFuncArgs { err_msg, .. }) => { - assert_eq!( - err_msg, - "The lengths of the vector are not aligned, args 0: 4, args 1: 3" - ) - } - _ => unreachable!(), - } + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!(e.to_string().starts_with( + "Internal error: Arguments has mixed length. Expected length: 4, found length: 3." + )); } } diff --git a/src/common/function/src/scalars/vector/vector_dim.rs b/src/common/function/src/scalars/vector/vector_dim.rs index 2f877ed9ea..8a21f6a3f8 100644 --- a/src/common/function/src/scalars/vector/vector_dim.rs +++ b/src/common/function/src/scalars/vector/vector_dim.rs @@ -12,19 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::error::Result; use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef}; -use snafu::ensure; +use datafusion_common::ScalarValue; +use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; -use crate::function::{Function, FunctionContext}; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const}; +use crate::function::Function; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::as_veclit; const NAME: &str = "vec_dim"; @@ -63,43 +62,20 @@ impl Function for VectorDimFunction { ) } - fn eval( + fn invoke_with_args( &self, - _func_ctx: &FunctionContext, - columns: &[VectorRef], - ) -> common_query::error::Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly one, have: {}", - columns.len() - ) - } - ); - let arg0 = &columns[0]; + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue| -> datafusion_common::Result { + let v = as_veclit(v0)?.map(|v0| v0.len() as u64); + Ok(ScalarValue::UInt64(v)) + }; - let len = arg0.len(); - let mut result = UInt64VectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - let Some(arg0) = arg0 else { - result.push_null(); - continue; - }; - result.push(Some(arg0.len() as u64)); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_single_argument(args) } } @@ -113,8 +89,10 @@ impl Display for VectorDimFunction { mod tests { use std::sync::Arc; - use common_query::error::Error; - use datatypes::vectors::StringVector; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, StringViewArray}; + use datafusion::arrow::datatypes::UInt64Type; + use datafusion_common::config::ConfigOptions; use super::*; @@ -122,49 +100,60 @@ mod tests { fn test_vec_dim() { let func = VectorDimFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[0.0,2.0,3.0]".to_string()), Some("[1.0,2.0,3.0,4.0]".to_string()), None, Some("[5.0]".to_string()), ])); - let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let result = func + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) + .unwrap(); - let result = result.as_ref(); + let result = result.as_primitive::(); assert_eq!(result.len(), 4); - assert_eq!(result.get_ref(0).as_u64().unwrap(), Some(3)); - assert_eq!(result.get_ref(1).as_u64().unwrap(), Some(4)); - assert!(result.get_ref(2).is_null()); - assert_eq!(result.get_ref(3).as_u64().unwrap(), Some(1)); + assert_eq!(result.value(0), 3); + assert_eq!(result.value(1), 4); + assert!(result.is_null(2)); + assert_eq!(result.value(3), 1); } #[test] fn test_dim_error() { let func = VectorDimFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), None, Some("[2.0,3.0,3.0]".to_string()), ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1 = Arc::new(StringViewArray::from(vec![ Some("[1.0,1.0,1.0]".to_string()), Some("[6.0,5.0,4.0]".to_string()), Some("[3.0,2.0,2.0]".to_string()), ])); - let result = func.eval(&FunctionContext::default(), &[input0, input1]); - - match result { - Err(Error::InvalidFuncArgs { err_msg, .. }) => { - assert_eq!( - err_msg, - "The length of the args is not correct, expect exactly one, have: 2" - ) - } - _ => unreachable!(), - } + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::UInt64, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!( + e.to_string() + .starts_with("Execution error: vec_dim function requires 1 argument, got 2") + ) } } diff --git a/src/common/function/src/scalars/vector/vector_div.rs b/src/common/function/src/scalars/vector/vector_div.rs index 2e2d6898a1..b76602ac7c 100644 --- a/src/common/function/src/scalars/vector/vector_div.rs +++ b/src/common/function/src/scalars/vector/vector_div.rs @@ -15,17 +15,17 @@ use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use common_query::error::Result; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::veclit_to_binlit; const NAME: &str = "vec_div"; @@ -52,7 +52,7 @@ impl Function for VectorDivFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -62,64 +62,36 @@ impl Function for VectorDivFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ), - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &Option>, + v1: &Option>| + -> datafusion_common::Result { + let result = if let (Some(v0), Some(v1)) = (v0, v1) { + let v0 = DVectorView::from_slice(v0, v0.len()); + let v1 = DVectorView::from_slice(v1, v1.len()); + if v0.len() != v1.len() { + return Err(DataFusionError::Execution(format!( + "vectors length not match: {}", + self.name() + ))); + } - let arg0 = &columns[0]; - let arg1 = &columns[1]; - - let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - let arg1_const = as_veclit_if_const(arg1)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - - let arg1 = match arg1_const.as_ref() { - Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())), - None => as_veclit(arg1.get_ref(i))?, - }; - - if let (Some(arg0), Some(arg1)) = (arg0, arg1) { - ensure!( - arg0.len() == arg1.len(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the vectors must match for division, have: {} vs {}", - arg0.len(), - arg1.len() - ), - } - ); - let vec0 = DVectorView::from_slice(&arg0, arg0.len()); - let vec1 = DVectorView::from_slice(&arg1, arg1.len()); - let vec_res = vec0.component_div(&vec1); - - let veclit = vec_res.as_slice(); - let binlit = veclit_to_binlit(veclit); - result.push(Some(&binlit)); + let result = veclit_to_binlit((v0.component_div(&v1)).as_slice()); + Some(result) } else { - result.push_null(); - } - } + None + }; + Ok(ScalarValue::BinaryView(result)) + }; - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_vectors(args) } } @@ -133,8 +105,9 @@ impl Display for VectorDivFunction { mod tests { use std::sync::Arc; - use common_query::error; - use datatypes::vectors::StringVector; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; @@ -144,69 +117,80 @@ mod tests { let vec0 = vec![1.0, 2.0, 3.0]; let vec1 = vec![1.0, 1.0]; - let (len0, len1) = (vec0.len(), vec1.len()); - let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))])); - let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))])); + let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))])); + let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))])); - let err = func - .eval(&FunctionContext::default(), &[input0, input1]) - .unwrap_err(); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert_eq!( + e.to_string(), + "Execution error: vectors length not match: vec_div" + ); - match err { - error::Error::InvalidFuncArgs { err_msg, .. } => { - assert_eq!( - err_msg, - format!( - "The length of the vectors must match for division, have: {} vs {}", - len0, len1 - ) - ) - } - _ => unreachable!(), - } - - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[8.0,10.0,12.0]".to_string()), Some("[7.0,8.0,9.0]".to_string()), None, ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1 = Arc::new(StringViewArray::from(vec![ Some("[1.0,1.0,1.0]".to_string()), Some("[2.0,2.0,2.0]".to_string()), None, Some("[3.0,3.0,3.0]".to_string()), ])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 4); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()) + result.value(0), + veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice() ); assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice()) + result.value(1), + veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice() ); - assert!(result.get_ref(2).is_null()); - assert!(result.get_ref(3).is_null()); + assert!(result.is_null(2)); + assert!(result.is_null(3)); - let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())])); - let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())])); + let input0 = Arc::new(StringViewArray::from(vec![Some("[1.0,-2.0]".to_string())])); + let input1 = Arc::new(StringViewArray::from(vec![Some("[0.0,0.0]".to_string())])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 2, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(2)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice()) + result.value(0), + veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice() ); } } diff --git a/src/common/function/src/scalars/vector/vector_kth_elem.rs b/src/common/function/src/scalars/vector/vector_kth_elem.rs index d1d110c47c..aae451e4d4 100644 --- a/src/common/function/src/scalars/vector/vector_kth_elem.rs +++ b/src/common/function/src/scalars/vector/vector_kth_elem.rs @@ -12,19 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; +use common_query::error::Result; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef}; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const}; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::as_veclit; const NAME: &str = "vec_kth_elem"; @@ -63,72 +62,44 @@ impl Function for VectorKthElemFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ), - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result { + let v0 = as_veclit(v0)?; - let arg0 = &columns[0]; - let arg1 = &columns[1]; + let v1 = match v1 { + ScalarValue::Int64(None) => return Ok(ScalarValue::Float32(None)), + ScalarValue::Int64(Some(v1)) if *v1 >= 0 => *v1 as usize, + _ => { + return Err(DataFusionError::Execution(format!( + "2nd argument not a valid index or expected datatype: {}", + self.name() + ))); + } + }; - let len = arg0.len(); - let mut result = Float32VectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); + let result = v0 + .map(|v0| { + if v1 >= v0.len() { + Err(DataFusionError::Execution(format!( + "index out of bound: {}", + self.name() + ))) + } else { + Ok(v0[v1]) + } + }) + .transpose()?; + Ok(ScalarValue::Float32(result)) }; - let arg0_const = as_veclit_if_const(arg0)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - let Some(arg0) = arg0 else { - result.push_null(); - continue; - }; - - let arg1 = arg1.get(i).as_f64_lossy(); - let Some(arg1) = arg1 else { - result.push_null(); - continue; - }; - - ensure!( - arg1 >= 0.0 && arg1.fract() == 0.0, - InvalidFuncArgsSnafu { - err_msg: format!( - "Invalid argument: k must be a non-negative integer, but got k = {}.", - arg1 - ), - } - ); - - let k = arg1 as usize; - - ensure!( - k < arg0.len(), - InvalidFuncArgsSnafu { - err_msg: format!( - "Out of range: k must be in the range [0, {}], but got k = {}.", - arg0.len() - 1, - k - ), - } - ); - - let value = arg0[k]; - - result.push(Some(value)); - } - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_args(args) } } @@ -142,8 +113,10 @@ impl Display for VectorKthElemFunction { mod tests { use std::sync::Arc; - use common_query::error; - use datatypes::vectors::{Int64Vector, StringVector}; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringViewArray}; + use datafusion::arrow::datatypes::Float32Type; + use datafusion_common::config::ConfigOptions; use super::*; @@ -151,55 +124,66 @@ mod tests { fn test_vec_kth_elem() { let func = VectorKthElemFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), Some("[7.0,8.0,9.0]".to_string()), None, ])); - let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)])); + let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(2), None, Some(1)])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::Float32, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) .unwrap(); - let result = result.as_ref(); + let result = result.as_primitive::(); assert_eq!(result.len(), 4); - assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0)); - assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0)); - assert!(result.get_ref(2).is_null()); - assert!(result.get_ref(3).is_null()); + assert_eq!(result.value(0), 1.0); + assert_eq!(result.value(1), 6.0); + assert!(result.is_null(2)); + assert!(result.is_null(3)); - let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())])); - let input1 = Arc::new(Int64Vector::from(vec![Some(3)])); + let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some( + "[1.0,2.0,3.0]".to_string(), + )])); + let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)])); - let err = func - .eval(&FunctionContext::default(), &[input0, input1]) - .unwrap_err(); - match err { - error::Error::InvalidFuncArgs { err_msg, .. } => { - assert_eq!( - err_msg, - format!("Out of range: k must be in the range [0, 2], but got k = 3.") - ) - } - _ => unreachable!(), - } + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Float32, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!( + e.to_string() + .starts_with("Execution error: index out of bound: vec_kth_elem") + ); - let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())])); - let input1 = Arc::new(Int64Vector::from(vec![Some(-1)])); + let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some( + "[1.0,2.0,3.0]".to_string(), + )])); + let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(-1)])); - let err = func - .eval(&FunctionContext::default(), &[input0, input1]) - .unwrap_err(); - match err { - error::Error::InvalidFuncArgs { err_msg, .. } => { - assert_eq!( - err_msg, - format!("Invalid argument: k must be a non-negative integer, but got k = -1.") - ) - } - _ => unreachable!(), - } + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::Float32, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!(e.to_string().starts_with( + "Execution error: 2nd argument not a valid index or expected datatype: vec_kth_elem" + )); } } diff --git a/src/common/function/src/scalars/vector/vector_mul.rs b/src/common/function/src/scalars/vector/vector_mul.rs index 99bf3efe5d..e25178cc7c 100644 --- a/src/common/function/src/scalars/vector/vector_mul.rs +++ b/src/common/function/src/scalars/vector/vector_mul.rs @@ -15,17 +15,17 @@ use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use common_query::error::Result; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::veclit_to_binlit; const NAME: &str = "vec_mul"; @@ -52,7 +52,7 @@ impl Function for VectorMulFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -62,64 +62,36 @@ impl Function for VectorMulFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ), - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &Option>, + v1: &Option>| + -> datafusion_common::Result { + let result = if let (Some(v0), Some(v1)) = (v0, v1) { + let v0 = DVectorView::from_slice(v0, v0.len()); + let v1 = DVectorView::from_slice(v1, v1.len()); + if v0.len() != v1.len() { + return Err(DataFusionError::Execution(format!( + "vectors length not match: {}", + self.name() + ))); + } - let arg0 = &columns[0]; - let arg1 = &columns[1]; - - let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - let arg1_const = as_veclit_if_const(arg1)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - - let arg1 = match arg1_const.as_ref() { - Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())), - None => as_veclit(arg1.get_ref(i))?, - }; - - if let (Some(arg0), Some(arg1)) = (arg0, arg1) { - ensure!( - arg0.len() == arg1.len(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the vectors must match for multiplying, have: {} vs {}", - arg0.len(), - arg1.len() - ), - } - ); - let vec0 = DVectorView::from_slice(&arg0, arg0.len()); - let vec1 = DVectorView::from_slice(&arg1, arg1.len()); - let vec_res = vec1.component_mul(&vec0); - - let veclit = vec_res.as_slice(); - let binlit = veclit_to_binlit(veclit); - result.push(Some(&binlit)); + let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice()); + Some(result) } else { - result.push_null(); - } - } + None + }; + Ok(ScalarValue::BinaryView(result)) + }; - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_vectors(args) } } @@ -133,8 +105,9 @@ impl Display for VectorMulFunction { mod tests { use std::sync::Arc; - use common_query::error; - use datatypes::vectors::StringVector; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; @@ -144,56 +117,59 @@ mod tests { let vec0 = vec![1.0, 2.0, 3.0]; let vec1 = vec![1.0, 1.0]; - let (len0, len1) = (vec0.len(), vec1.len()); - let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))])); - let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))])); + let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))])); + let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))])); - let err = func - .eval(&FunctionContext::default(), &[input0, input1]) - .unwrap_err(); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!( + e.to_string() + .starts_with("Execution error: vectors length not match: vec_mul") + ); - match err { - error::Error::InvalidFuncArgs { err_msg, .. } => { - assert_eq!( - err_msg, - format!( - "The length of the vectors must match for multiplying, have: {} vs {}", - len0, len1 - ) - ) - } - _ => unreachable!(), - } - - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[8.0,10.0,12.0]".to_string()), Some("[7.0,8.0,9.0]".to_string()), None, ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1 = Arc::new(StringViewArray::from(vec![ Some("[1.0,1.0,1.0]".to_string()), Some("[2.0,2.0,2.0]".to_string()), None, Some("[3.0,3.0,3.0]".to_string()), ])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 4); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()) + result.value(0), + veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice() ); assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()) + result.value(1), + veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice() ); - assert!(result.get_ref(2).is_null()); - assert!(result.get_ref(3).is_null()); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } } diff --git a/src/common/function/src/scalars/vector/vector_norm.rs b/src/common/function/src/scalars/vector/vector_norm.rs index bd82efbdbc..9a7d19a371 100644 --- a/src/common/function/src/scalars/vector/vector_norm.rs +++ b/src/common/function/src/scalars/vector/vector_norm.rs @@ -12,20 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::error::Result; use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use datafusion_common::ScalarValue; +use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::function::Function; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; const NAME: &str = "vec_norm"; @@ -53,7 +52,7 @@ impl Function for VectorNormFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -66,55 +65,27 @@ impl Function for VectorNormFunction { ) } - fn eval( + fn invoke_with_args( &self, - _func_ctx: &FunctionContext, - columns: &[VectorRef], - ) -> common_query::error::Result { - ensure!( - columns.len() == 1, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly one, have: {}", - columns.len() - ) - } - ); - let arg0 = &columns[0]; - - let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - let Some(arg0) = arg0 else { - result.push_null(); - continue; + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &ScalarValue| -> datafusion_common::Result { + let v0 = as_veclit(v0)?; + let Some(v0) = v0 else { + return Ok(ScalarValue::BinaryView(None)); }; - let vec0 = DVectorView::from_slice(&arg0, arg0.len()); - let vec1 = DVectorView::from_slice(&arg0, arg0.len()); - let vec2scalar = vec1.component_mul(&vec0); - let scalar_var = vec2scalar.sum().sqrt(); + let v0 = DVectorView::from_slice(&v0, v0.len()); + let result = + veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice()); + Ok(ScalarValue::BinaryView(Some(result))) + }; - let vec = DVectorView::from_slice(&arg0, arg0.len()); - // Use unscale to avoid division by zero and keep more precision as possible - let vec_res = vec.unscale(scalar_var); - - let veclit = vec_res.as_slice(); - let binlit = veclit_to_binlit(veclit); - result.push(Some(&binlit)); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_single_argument(args) } } @@ -128,7 +99,9 @@ impl Display for VectorNormFunction { mod tests { use std::sync::Arc; - use datatypes::vectors::StringVector; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, AsArray, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; @@ -136,7 +109,7 @@ mod tests { fn test_vec_norm() { let func = VectorNormFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[0.0,2.0,3.0]".to_string()), Some("[1.0,2.0,3.0]".to_string()), Some("[7.0,8.0,9.0]".to_string()), @@ -144,26 +117,36 @@ mod tests { None, ])); - let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0)], + arg_fields: vec![], + number_rows: 5, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let result = func + .invoke_with_args(args) + .and_then(|x| x.to_array(5)) + .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 5); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice()) + result.value(0), + veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice() ); assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice()) + result.value(1), + veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice() ); assert_eq!( - result.get_ref(2).as_binary().unwrap(), - Some(veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice()) + result.value(2), + veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice() ); assert_eq!( - result.get_ref(3).as_binary().unwrap(), - Some(veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice()) + result.value(3), + veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice() ); - assert!(result.get_ref(4).is_null()); + assert!(result.is_null(4)); } } diff --git a/src/common/function/src/scalars/vector/vector_sub.rs b/src/common/function/src/scalars/vector/vector_sub.rs index a849df2a3d..43b33d152f 100644 --- a/src/common/function/src/scalars/vector/vector_sub.rs +++ b/src/common/function/src/scalars/vector/vector_sub.rs @@ -15,17 +15,17 @@ use std::borrow::Cow; use std::fmt::Display; -use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::Signature; -use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use common_query::error::Result; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ScalarFunctionArgs, Signature}; use nalgebra::DVectorView; -use snafu::ensure; -use crate::function::{Function, FunctionContext}; +use crate::function::Function; use crate::helper; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::scalars::vector::VectorCalculator; +use crate::scalars::vector::impl_conv::veclit_to_binlit; const NAME: &str = "vec_sub"; @@ -51,7 +51,7 @@ impl Function for VectorSubFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -61,66 +61,36 @@ impl Function for VectorSubFunction { ) } - fn eval( + fn invoke_with_args( &self, - _func_ctx: &FunctionContext, - columns: &[VectorRef], - ) -> common_query::error::Result { - ensure!( - columns.len() == 2, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly two, have: {}", - columns.len() - ) - } - ); - let arg0 = &columns[0]; - let arg1 = &columns[1]; + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let body = |v0: &Option>, + v1: &Option>| + -> datafusion_common::Result { + let result = if let (Some(v0), Some(v1)) = (v0, v1) { + let v0 = DVectorView::from_slice(v0, v0.len()); + let v1 = DVectorView::from_slice(v1, v1.len()); + if v0.len() != v1.len() { + return Err(DataFusionError::Execution(format!( + "vectors length not match: {}", + self.name() + ))); + } - ensure!( - arg0.len() == arg1.len(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The lengths of the vector are not aligned, args 0: {}, args 1: {}", - arg0.len(), - arg1.len(), - ) - } - ); - - let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); - if len == 0 { - return Ok(result.to_vector()); - } - - let arg0_const = as_veclit_if_const(arg0)?; - let arg1_const = as_veclit_if_const(arg1)?; - - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, + let result = veclit_to_binlit((v0 - v1).as_slice()); + Some(result) + } else { + None }; - let arg1 = match arg1_const.as_ref() { - Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())), - None => as_veclit(arg1.get_ref(i))?, - }; - let (Some(arg0), Some(arg1)) = (arg0, arg1) else { - result.push_null(); - continue; - }; - let vec0 = DVectorView::from_slice(&arg0, arg0.len()); - let vec1 = DVectorView::from_slice(&arg1, arg1.len()); + Ok(ScalarValue::BinaryView(result)) + }; - let vec_res = vec0 - vec1; - let veclit = vec_res.as_slice(); - let binlit = veclit_to_binlit(veclit); - result.push(Some(&binlit)); - } - - Ok(result.to_vector()) + let calculator = VectorCalculator { + name: self.name(), + func: body, + }; + calculator.invoke_with_vectors(args) } } @@ -134,8 +104,9 @@ impl Display for VectorSubFunction { mod tests { use std::sync::Arc; - use common_query::error::Error; - use datatypes::vectors::StringVector; + use arrow_schema::Field; + use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; @@ -143,63 +114,71 @@ mod tests { fn test_sub() { let func = VectorSubFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), None, Some("[2.0,3.0,3.0]".to_string()), ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[1.0,1.0,1.0]".to_string()), Some("[6.0,5.0,4.0]".to_string()), Some("[3.0,2.0,2.0]".to_string()), None, ])); + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1]) + .invoke_with_args(args) + .and_then(|x| x.to_array(4)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 4); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice()) + result.value(0), + veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice() ); assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice()) + result.value(1), + veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice() ); - assert!(result.get_ref(2).is_null()); - assert!(result.get_ref(3).is_null()); + assert!(result.is_null(2)); + assert!(result.is_null(3)); } #[test] fn test_sub_error() { let func = VectorSubFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[1.0,2.0,3.0]".to_string()), Some("[4.0,5.0,6.0]".to_string()), None, Some("[2.0,3.0,3.0]".to_string()), ])); - let input1 = Arc::new(StringVector::from(vec![ + let input1: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[1.0,1.0,1.0]".to_string()), Some("[6.0,5.0,4.0]".to_string()), Some("[3.0,2.0,2.0]".to_string()), ])); - let result = func.eval(&FunctionContext::default(), &[input0, input1]); - - match result { - Err(Error::InvalidFuncArgs { err_msg, .. }) => { - assert_eq!( - err_msg, - "The lengths of the vector are not aligned, args 0: 4, args 1: 3" - ) - } - _ => unreachable!(), - } + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)], + arg_fields: vec![], + number_rows: 4, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!(e.to_string().starts_with( + "Internal error: Arguments has mixed length. Expected length: 4, found length: 3." + )); } } diff --git a/src/common/function/src/scalars/vector/vector_subvector.rs b/src/common/function/src/scalars/vector/vector_subvector.rs index cda358d0f9..239edaaa93 100644 --- a/src/common/function/src/scalars/vector/vector_subvector.rs +++ b/src/common/function/src/scalars/vector/vector_subvector.rs @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::fmt::Display; +use std::sync::Arc; use common_query::error::{InvalidFuncArgsSnafu, Result}; -use datafusion_expr::{Signature, TypeSignature, Volatility}; +use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder}; +use datafusion::arrow::datatypes::Int64Type; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ScalarValue, utils}; +use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility}; use datatypes::arrow::datatypes::DataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; use snafu::ensure; -use crate::function::{Function, FunctionContext}; -use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +use crate::function::Function; +use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; const NAME: &str = "vec_subvector"; @@ -52,7 +54,7 @@ impl Function for VectorSubvectorFunction { } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Binary) + Ok(DataType::BinaryView) } fn signature(&self) -> Signature { @@ -65,50 +67,28 @@ impl Function for VectorSubvectorFunction { ) } - fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { - ensure!( - columns.len() == 3, - InvalidFuncArgsSnafu { - err_msg: format!( - "The length of the args is not correct, expect exactly three, have: {}", - columns.len() - ) - } - ); - - let arg0 = &columns[0]; - let arg1 = &columns[1]; - let arg2 = &columns[2]; - - ensure!( - arg0.len() == arg1.len() && arg1.len() == arg2.len(), - InvalidFuncArgsSnafu { - err_msg: format!( - "The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}", - arg0.len(), - arg1.len(), - arg2.len() - ) - } - ); + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?; + let arg1 = arg1.as_primitive::(); + let arg2 = arg2.as_primitive::(); let len = arg0.len(); - let mut result = BinaryVectorBuilder::with_capacity(len); + let mut builder = BinaryViewBuilder::with_capacity(len); if len == 0 { - return Ok(result.to_vector()); + return Ok(ColumnarValue::Array(Arc::new(builder.finish()))); } - let arg0_const = as_veclit_if_const(arg0)?; - for i in 0..len { - let arg0 = match arg0_const.as_ref() { - Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), - None => as_veclit(arg0.get_ref(i))?, - }; - let arg1 = arg1.get(i).as_i64(); - let arg2 = arg2.get(i).as_i64(); + let v = ScalarValue::try_from_array(&arg0, i)?; + let arg0 = as_veclit(&v)?; + let arg1 = arg1.is_valid(i).then(|| arg1.value(i)); + let arg2 = arg2.is_valid(i).then(|| arg2.value(i)); let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else { - result.push_null(); + builder.append_null(); continue; }; @@ -126,10 +106,10 @@ impl Function for VectorSubvectorFunction { let subvector = &arg0[arg1 as usize..arg2 as usize]; let binlit = veclit_to_binlit(subvector); - result.push(Some(&binlit)); + builder.append_value(&binlit); } - Ok(result.to_vector()) + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } @@ -143,89 +123,102 @@ impl Display for VectorSubvectorFunction { mod tests { use std::sync::Arc; - use common_query::error::Error; - use datatypes::vectors::{Int64Vector, StringVector}; + use arrow_schema::Field; + use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray}; + use datafusion_common::config::ConfigOptions; use super::*; - use crate::function::FunctionContext; + #[test] fn test_subvector() { let func = VectorSubvectorFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()), Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()), None, Some("[11.0, 12.0, 13.0]".to_string()), ])); - let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)])); - let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)])); + let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)])); + let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)])); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input0), + ColumnarValue::Array(input1), + ColumnarValue::Array(input2), + ], + arg_fields: vec![], + number_rows: 5, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; let result = func - .eval(&FunctionContext::default(), &[input0, input1, input2]) + .invoke_with_args(args) + .and_then(|x| x.to_array(5)) .unwrap(); - let result = result.as_ref(); + let result = result.as_binary_view(); assert_eq!(result.len(), 4); + assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice()); assert_eq!( - result.get_ref(0).as_binary().unwrap(), - Some(veclit_to_binlit(&[2.0, 3.0]).as_slice()) - ); - assert_eq!( - result.get_ref(1).as_binary().unwrap(), - Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()) - ); - assert!(result.get_ref(2).is_null()); - assert_eq!( - result.get_ref(3).as_binary().unwrap(), - Some(veclit_to_binlit(&[12.0, 13.0]).as_slice()) + result.value(1), + veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice() ); + assert!(result.is_null(2)); + assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice()); } #[test] fn test_subvector_error() { let func = VectorSubvectorFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0: ArrayRef = Arc::new(StringViewArray::from(vec![ Some("[1.0, 2.0, 3.0]".to_string()), Some("[4.0, 5.0, 6.0]".to_string()), ])); - let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)])); - let input2 = Arc::new(Int64Vector::from(vec![Some(3)])); + let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)])); + let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)])); - let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]); - - match result { - Err(Error::InvalidFuncArgs { err_msg, .. }) => { - assert_eq!( - err_msg, - "The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1" - ) - } - _ => unreachable!(), - } + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input0), + ColumnarValue::Array(input1), + ColumnarValue::Array(input2), + ], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!(e.to_string().starts_with( + "Internal error: Arguments has mixed length. Expected length: 2, found length: 1." + )); } #[test] fn test_subvector_invalid_indices() { let func = VectorSubvectorFunction; - let input0 = Arc::new(StringVector::from(vec![ + let input0 = Arc::new(StringViewArray::from(vec![ Some("[1.0, 2.0, 3.0]".to_string()), Some("[4.0, 5.0, 6.0]".to_string()), ])); - let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)])); - let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)])); + let input1 = Arc::new(Int64Array::from(vec![Some(1), Some(3)])); + let input2 = Arc::new(Int64Array::from(vec![Some(3), Some(4)])); - let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]); - - match result { - Err(Error::InvalidFuncArgs { err_msg, .. }) => { - assert_eq!( - err_msg, - "Invalid start and end indices: start=3, end=4, vec_len=3" - ) - } - _ => unreachable!(), - } + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input0), + ColumnarValue::Array(input1), + ColumnarValue::Array(input2), + ], + arg_fields: vec![], + number_rows: 3, + return_field: Arc::new(Field::new("x", DataType::BinaryView, false)), + config_options: Arc::new(ConfigOptions::new()), + }; + let e = func.invoke_with_args(args).unwrap_err(); + assert!(e.to_string().starts_with("External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3")); } } diff --git a/src/query/src/tests/vec_product_test.rs b/src/query/src/tests/vec_product_test.rs index 6f49dd711e..53eb0d3272 100644 --- a/src/query/src/tests/vec_product_test.rs +++ b/src/query/src/tests/vec_product_test.rs @@ -12,11 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; - -use common_function::scalars::vector::impl_conv::{ - as_veclit, as_veclit_if_const, veclit_to_binlit, -}; +use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; +use datafusion_common::ScalarValue; use datatypes::prelude::Value; use nalgebra::{Const, DVectorView, Dyn, OVector}; @@ -35,14 +32,10 @@ async fn test_vec_product_aggregator() -> Result<(), common_query::error::Error> let sql = "SELECT vector FROM vectors"; let vectors = exec_selection(engine, sql).await; - let column = vectors[0].column(0); - let vector_const = as_veclit_if_const(column)?; - + let column = vectors[0].column(0).to_arrow_array(); for i in 0..column.len() { - let vector = match vector_const.as_ref() { - Some(vector) => Some(Cow::Borrowed(vector.as_ref())), - None => as_veclit(column.get_ref(i))?, - }; + let v = ScalarValue::try_from_array(&column, i)?; + let vector = as_veclit(&v)?; let Some(vector) = vector else { expected_value = None; break; diff --git a/src/query/src/tests/vec_sum_test.rs b/src/query/src/tests/vec_sum_test.rs index 5727a24f2e..2c488c3c53 100644 --- a/src/query/src/tests/vec_sum_test.rs +++ b/src/query/src/tests/vec_sum_test.rs @@ -12,12 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::ops::AddAssign; -use common_function::scalars::vector::impl_conv::{ - as_veclit, as_veclit_if_const, veclit_to_binlit, -}; +use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit}; +use datafusion_common::ScalarValue; use datatypes::prelude::Value; use nalgebra::{Const, DVectorView, Dyn, OVector}; @@ -36,14 +34,10 @@ async fn test_vec_sum_aggregator() -> Result<(), common_query::error::Error> { let sql = "SELECT vector FROM vectors"; let vectors = exec_selection(engine, sql).await; - let column = vectors[0].column(0); - let vector_const = as_veclit_if_const(column)?; - + let column = vectors[0].column(0).to_arrow_array(); for i in 0..column.len() { - let vector = match vector_const.as_ref() { - Some(vector) => Some(Cow::Borrowed(vector.as_ref())), - None => as_veclit(column.get_ref(i))?, - }; + let v = ScalarValue::try_from_array(&column, i)?; + let vector = as_veclit(&v)?; let Some(vector) = vector else { expected_value = None; break; diff --git a/tests/cases/standalone/common/function/arithmetic.result b/tests/cases/standalone/common/function/arithmetic.result index 91087bec17..01d2c7e062 100644 --- a/tests/cases/standalone/common/function/arithmetic.result +++ b/tests/cases/standalone/common/function/arithmetic.result @@ -76,13 +76,7 @@ SELECT CLAMP(0.5, 0, 1); SELECT CLAMP(10, 1, 0); -Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray -[ - 1, -] }, PrimitiveVector { array: PrimitiveArray -[ - 0, -] } +Error: 3001(EngineExecuteQuery), Execution error: min '1' > max '0' SELECT CLAMP_MIN(10, 12); diff --git a/tests/cases/standalone/common/function/geo.result b/tests/cases/standalone/common/function/geo.result index 607908b3e0..c9a064d405 100644 --- a/tests/cases/standalone/common/function/geo.result +++ b/tests/cases/standalone/common/function/geo.result @@ -240,11 +240,11 @@ SELECT geohash(37.76938, -122.3889, 11); SELECT geohash(37.76938, -122.3889, 100); -Error: 3001(EngineExecuteQuery), Invalid geohash resolution 100, expect value: [1, 12] +Error: 3001(EngineExecuteQuery), Execution error: Invalid geohash resolution 100, valid value range: [1, 12] SELECT geohash(37.76938, -122.3889, -1); -Error: 3001(EngineExecuteQuery), Invalid geohash resolution -1, expect value: [1, 12] +Error: 3001(EngineExecuteQuery), Cast error: Can't cast value -1 to type UInt8 SELECT geohash(37.76938, -122.3889, 11::Int8); diff --git a/tests/cases/standalone/common/promql/scalar.result b/tests/cases/standalone/common/promql/scalar.result index 12a4481ad0..c5c3e5ebd1 100644 --- a/tests/cases/standalone/common/promql/scalar.result +++ b/tests/cases/standalone/common/promql/scalar.result @@ -375,17 +375,7 @@ TQL EVAL (0, 15, '5s') clamp(host, 6 - 6, 6 + 6); -- SQLNESS SORT_RESULT 3 1 TQL EVAL (0, 15, '5s') clamp(host, 12, 0); -Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray -[ - 12.0, - 0.0, - 0.0, - 0.0, - 12.0, - 12.0, -[ -] }, PrimitiveVector { array: PrimitiveArray -] } +Error: 3001(EngineExecuteQuery), Execution error: min '12' > max '0' -- SQLNESS SORT_RESULT 3 1 TQL EVAL (0, 15, '5s') clamp(host{host="host1"}, -1, 6); diff --git a/tests/cases/standalone/common/types/vector/vector.result b/tests/cases/standalone/common/types/vector/vector.result index 3d40f4f8b2..bc9a38c816 100644 --- a/tests/cases/standalone/common/types/vector/vector.result +++ b/tests/cases/standalone/common/types/vector/vector.result @@ -99,7 +99,7 @@ SELECT round(vec_cos_distance(v, v), 2) FROM t; -- Unexpected dimension -- SELECT vec_cos_distance(v, '[1.0]') FROM t; -Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 +Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_cos_distance -- Invalid type -- SELECT vec_cos_distance(v, 1.0) FROM t; @@ -174,7 +174,7 @@ SELECT round(vec_l2sq_distance(v, v), 2) FROM t; -- Unexpected dimension -- SELECT vec_l2sq_distance(v, '[1.0]') FROM t; -Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 +Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_l2sq_distance -- Invalid type -- SELECT vec_l2sq_distance(v, 1.0) FROM t; @@ -249,7 +249,7 @@ SELECT round(vec_dot_product(v, v), 2) FROM t; -- Unexpected dimension -- SELECT vec_dot_product(v, '[1.0]') FROM t; -Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 +Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_dot_product -- Invalid type -- SELECT vec_dot_product(v, 1.0) FROM t;