refactor: rewrite some UDFs to DataFusion style (part 2) (#6967)

* refactor: rewrite some UDFs to DataFusion style (part 2)

Signed-off-by: luofucong <luofc@foxmail.com>

* deal with vector UDFs `(scalar, scalar)` situation, and try getting the scalar value reference everytime

Signed-off-by: luofucong <luofc@foxmail.com>

* reduce some vector literal parsing

Signed-off-by: luofucong <luofc@foxmail.com>

* fix ci

Signed-off-by: luofucong <luofc@foxmail.com>

---------

Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
LFC
2025-09-18 14:37:27 +08:00
committed by GitHub
parent e26b98f452
commit cbe0cf4a74
30 changed files with 1663 additions and 1877 deletions

View File

@@ -14,14 +14,15 @@
use std::fmt;
use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::{ArrowComputeSnafu, Result};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::utils;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::compute::kernels::numeric;
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datatypes::vectors::{Helper, VectorRef};
use snafu::{ResultExt, ensure};
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
/// A function adds an interval value to Timestamp, Date, and return the result.
@@ -58,25 +59,15 @@ impl Function for DateAddFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 2, have: {}",
columns.len()
),
}
);
let left = columns[0].to_arrow_array();
let right = columns[1].to_arrow_array();
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [left, right] = utils::take_function_args(self.name(), args)?;
let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
let arrow_type = result.data_type().clone();
Helper::try_into_vector(result).context(IntoVectorSnafu {
data_type: arrow_type,
})
Ok(ColumnarValue::Array(result))
}
}
@@ -90,12 +81,14 @@ impl fmt::Display for DateAddFunction {
mod tests {
use std::sync::Arc;
use datafusion_expr::{TypeSignature, Volatility};
use datatypes::arrow::datatypes::IntervalDayTime;
use datatypes::value::Value;
use datatypes::vectors::{
DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
use arrow_schema::Field;
use datafusion::arrow::array::{
Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
TimestampSecondArray,
};
use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{TypeSignature, Volatility};
use super::{DateAddFunction, *};
@@ -142,25 +135,37 @@ mod tests {
];
let results = [Some(124), None, Some(45), None];
let time_vector = TimestampSecondVector::from(times.clone());
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<TimestampSecondType>();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Timestamp(ts) => {
assert_eq!(ts.value(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}
@@ -174,25 +179,37 @@ mod tests {
let intervals = vec![1, 2, 3, 1];
let results = [Some(154), None, Some(131), None];
let date_vector = DateVector::from(dates.clone());
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<Date32Type>();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Date(date) => {
assert_eq!(date.val(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}

View File

@@ -14,14 +14,15 @@
use std::fmt;
use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::{ArrowComputeSnafu, Result};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::utils;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::compute::kernels::numeric;
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datatypes::vectors::{Helper, VectorRef};
use snafu::{ResultExt, ensure};
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
/// A function subtracts an interval value to Timestamp, Date, and return the result.
@@ -58,25 +59,15 @@ impl Function for DateSubFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 2, have: {}",
columns.len()
),
}
);
let left = columns[0].to_arrow_array();
let right = columns[1].to_arrow_array();
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [left, right] = utils::take_function_args(self.name(), args)?;
let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
let arrow_type = result.data_type().clone();
Helper::try_into_vector(result).context(IntoVectorSnafu {
data_type: arrow_type,
})
Ok(ColumnarValue::Array(result))
}
}
@@ -90,12 +81,14 @@ impl fmt::Display for DateSubFunction {
mod tests {
use std::sync::Arc;
use datafusion_expr::{TypeSignature, Volatility};
use datatypes::arrow::datatypes::IntervalDayTime;
use datatypes::value::Value;
use datatypes::vectors::{
DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
use arrow_schema::Field;
use datafusion::arrow::array::{
Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
TimestampSecondArray,
};
use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{TypeSignature, Volatility};
use super::{DateSubFunction, *};
@@ -142,25 +135,37 @@ mod tests {
];
let results = [Some(122), None, Some(39), None];
let time_vector = TimestampSecondVector::from(times.clone());
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<TimestampSecondType>();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Timestamp(ts) => {
assert_eq!(ts.value(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}
@@ -180,25 +185,37 @@ mod tests {
let intervals = vec![1, 2, 3, 1];
let results = [Some(3659), None, Some(1168), None];
let date_vector = DateVector::from(dates.clone());
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<Date32Type>();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Date(date) => {
assert_eq!(date.val(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}

View File

@@ -17,62 +17,26 @@ use std::sync::Arc;
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{self, InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::Field;
use common_query::error::{self, Result};
use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder};
use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, utils};
use datafusion_expr::type_coercion::aggregates::INTEGERS;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::{Scalar, ScalarVectorBuilder};
use datatypes::value::{ListValue, Value};
use datatypes::vectors::{ListVectorBuilder, MutableVector, StringVectorBuilder, VectorRef};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use geohash::Coord;
use snafu::{ResultExt, ensure};
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::scalars::geo::helpers;
macro_rules! ensure_resolution_usize {
($v: ident) => {
if !($v > 0 && $v <= 12) {
Err(BoxedError::new(PlainError::new(
format!("Invalid geohash resolution {}, expect value: [1, 12]", $v),
StatusCode::EngineExecuteQuery,
)))
.context(error::ExecuteSnafu)
} else {
Ok($v as usize)
}
};
}
fn try_into_resolution(v: Value) -> Result<usize> {
match v {
Value::Int8(v) => {
ensure_resolution_usize!(v)
}
Value::Int16(v) => {
ensure_resolution_usize!(v)
}
Value::Int32(v) => {
ensure_resolution_usize!(v)
}
Value::Int64(v) => {
ensure_resolution_usize!(v)
}
Value::UInt8(v) => {
ensure_resolution_usize!(v)
}
Value::UInt16(v) => {
ensure_resolution_usize!(v)
}
Value::UInt32(v) => {
ensure_resolution_usize!(v)
}
Value::UInt64(v) => {
ensure_resolution_usize!(v)
}
_ => unreachable!(),
fn ensure_resolution_usize(v: u8) -> datafusion_common::Result<usize> {
if v == 0 || v > 12 {
return Err(DataFusionError::Execution(format!(
"Invalid geohash resolution {v}, valid value range: [1, 12]"
)));
}
Ok(v as usize)
}
/// Function that return geohash string for a given geospatial coordinate.
@@ -109,31 +73,33 @@ impl Function for GeohashFunction {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 3, provided : {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
let lat_vec = &columns[0];
let lon_vec = &columns[1];
let resolution_vec = &columns[2];
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
let lon_vec = lon_vec.as_primitive::<Float64Type>();
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
let resolutions = resolutions.as_primitive::<UInt8Type>();
let size = lat_vec.len();
let mut results = StringVectorBuilder::with_capacity(size);
let mut builder = StringViewBuilder::with_capacity(size);
for i in 0..size {
let lat = lat_vec.get(i).as_f64_lossy();
let lon = lon_vec.get(i).as_f64_lossy();
let r = try_into_resolution(resolution_vec.get(i))?;
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
let r = resolutions
.is_valid(i)
.then(|| ensure_resolution_usize(resolutions.value(i)))
.transpose()?;
let result = match (lat, lon) {
(Some(lat), Some(lon)) => {
let result = match (lat, lon, r) {
(Some(lat), Some(lon), Some(r)) => {
let coord = Coord { x: lon, y: lat };
let encoded = geohash::encode(coord, r)
.map_err(|e| {
@@ -148,10 +114,10 @@ impl Function for GeohashFunction {
_ => None,
};
results.push(result.as_deref());
builder.append_option(result);
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -176,8 +142,8 @@ impl Function for GeohashNeighboursFunction {
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::List(Arc::new(Field::new(
"x",
DataType::Utf8,
"item",
DataType::Utf8View,
false,
))))
}
@@ -199,32 +165,33 @@ impl Function for GeohashNeighboursFunction {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 3, provided : {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
let lat_vec = &columns[0];
let lon_vec = &columns[1];
let resolution_vec = &columns[2];
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
let lon_vec = lon_vec.as_primitive::<Float64Type>();
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
let resolutions = resolutions.as_primitive::<UInt8Type>();
let size = lat_vec.len();
let mut results =
ListVectorBuilder::with_type_capacity(ConcreteDataType::string_datatype(), size);
let mut builder = ListBuilder::new(StringViewBuilder::new());
for i in 0..size {
let lat = lat_vec.get(i).as_f64_lossy();
let lon = lon_vec.get(i).as_f64_lossy();
let r = try_into_resolution(resolution_vec.get(i))?;
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
let r = resolutions
.is_valid(i)
.then(|| ensure_resolution_usize(resolutions.value(i)))
.transpose()?;
let result = match (lat, lon) {
(Some(lat), Some(lon)) => {
match (lat, lon, r) {
(Some(lat), Some(lon), Some(r)) => {
let coord = Coord { x: lon, y: lat };
let encoded = geohash::encode(coord, r)
.map_err(|e| {
@@ -242,8 +209,8 @@ impl Function for GeohashNeighboursFunction {
))
})
.context(error::ExecuteSnafu)?;
Some(ListValue::new(
vec![
builder.append_value(
[
neighbours.n,
neighbours.nw,
neighbours.w,
@@ -254,22 +221,14 @@ impl Function for GeohashNeighboursFunction {
neighbours.ne,
]
.into_iter()
.map(Value::from)
.collect(),
ConcreteDataType::string_datatype(),
))
.map(Some),
);
}
_ => None,
_ => builder.append_null(),
};
if let Some(list_value) = result {
results.push(Some(list_value.as_scalar_ref()));
} else {
results.push(None);
}
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}

View File

@@ -19,13 +19,11 @@ use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{self, Result};
use datafusion::arrow::array::{
Array, ArrayRef, AsArray, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder,
StringViewArray, StringViewBuilder, UInt8Builder, UInt64Builder,
Array, AsArray, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder, StringViewArray,
StringViewBuilder, UInt8Builder, UInt64Builder,
};
use datafusion::arrow::compute;
use datafusion::arrow::datatypes::{
ArrowPrimitiveType, Float64Type, Int64Type, UInt8Type, UInt64Type,
};
use datafusion::arrow::datatypes::{Float64Type, Int64Type, UInt8Type, UInt64Type};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue, utils};
use datafusion_expr::type_coercion::aggregates::INTEGERS;
@@ -36,6 +34,7 @@ use h3o::{CellIndex, LatLng, Resolution};
use snafu::prelude::*;
use crate::function::Function;
use crate::scalars::geo::helpers;
static CELL_TYPES: LazyLock<Vec<DataType>> =
LazyLock::new(|| vec![DataType::Int64, DataType::UInt64, DataType::Utf8]);
@@ -89,11 +88,11 @@ impl Function for H3LatLngToCell {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?;
let lat_vec = cast::<Float64Type>(&lat_vec)?;
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
let lon_vec = cast::<Float64Type>(&lon_vec)?;
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
let lon_vec = lon_vec.as_primitive::<Float64Type>();
let resolutions = cast::<UInt8Type>(&resolution_vec)?;
let resolutions = helpers::cast::<UInt8Type>(&resolution_vec)?;
let resolution_vec = resolutions.as_primitive::<UInt8Type>();
let size = lat_vec.len();
@@ -171,11 +170,11 @@ impl Function for H3LatLngToCellString {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?;
let lat_vec = cast::<Float64Type>(&lat_vec)?;
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
let lon_vec = cast::<Float64Type>(&lon_vec)?;
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
let lon_vec = lon_vec.as_primitive::<Float64Type>();
let resolutions = cast::<UInt8Type>(&resolution_vec)?;
let resolutions = helpers::cast::<UInt8Type>(&resolution_vec)?;
let resolution_vec = resolutions.as_primitive::<UInt8Type>();
let size = lat_vec.len();
@@ -547,7 +546,7 @@ impl Function for H3CellToChildren {
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [cell_vec, res_vec] = utils::take_function_args(self.name(), args)?;
let resolutions = cast::<UInt8Type>(&res_vec)?;
let resolutions = helpers::cast::<UInt8Type>(&res_vec)?;
let resolutions = resolutions.as_primitive::<UInt8Type>();
let size = cell_vec.len();
@@ -641,7 +640,7 @@ where
{
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [cells, resolutions] = utils::take_function_args(name, args)?;
let resolutions = cast::<UInt8Type>(&resolutions)?;
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
let resolutions = resolutions.as_primitive::<UInt8Type>();
let mut builder = UInt64Builder::with_capacity(cells.len());
@@ -698,7 +697,7 @@ impl Function for H3ChildPosToCell {
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [pos_vec, cell_vec, res_vec] = utils::take_function_args(self.name(), args)?;
let resolutions = cast::<UInt8Type>(&res_vec)?;
let resolutions = helpers::cast::<UInt8Type>(&res_vec)?;
let resolutions = resolutions.as_primitive::<UInt8Type>();
let size = cell_vec.len();
@@ -722,18 +721,6 @@ impl Function for H3ChildPosToCell {
}
}
fn cast<T: ArrowPrimitiveType>(array: &ArrayRef) -> datafusion_common::Result<ArrayRef> {
let x = compute::cast_with_options(
array.as_ref(),
&T::DATA_TYPE,
&compute::CastOptions {
safe: false,
..Default::default()
},
)?;
Ok(x)
}
/// Function that returns cells with k distances of given cell
#[derive(Clone, Debug, Default, Display)]
#[display("{}", self.name())]

View File

@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType};
use datafusion::arrow::compute;
macro_rules! ensure_columns_len {
($columns:ident) => {
snafu::ensure!(
@@ -73,3 +76,15 @@ macro_rules! ensure_and_coerce {
}
pub(crate) use ensure_and_coerce;
pub(crate) fn cast<T: ArrowPrimitiveType>(array: &ArrayRef) -> datafusion_common::Result<ArrayRef> {
let x = compute::cast_with_options(
array.as_ref(),
&T::DATA_TYPE,
&compute::CastOptions {
safe: false,
..Default::default()
},
)?;
Ok(x)
}

View File

@@ -16,21 +16,20 @@ use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use common_query::error::{IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result};
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
use datafusion::common::{DFSchema, Result as DfResult};
use datafusion::execution::SessionStateBuilder;
use datafusion::logical_expr::{self, Expr, Volatility};
use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
use datafusion_expr::Signature;
use datafusion_common::{DataFusionError, utils};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::array::RecordBatch;
use datatypes::arrow::datatypes::{DataType, Field};
use datatypes::prelude::VectorRef;
use datatypes::vectors::BooleanVector;
use snafu::{OptionExt, ResultExt, ensure};
use store_api::storage::ConcreteDataType;
use snafu::{OptionExt, ensure};
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
/// `matches` for full text search.
@@ -65,38 +64,36 @@ impl Function for MatchesFunction {
}
// TODO: read case-sensitive config
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
columns.len()
),
}
);
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [data_column, patterns] = utils::take_function_args(self.name(), args)?;
let data_column = &columns[0];
if data_column.is_empty() {
return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
Vec::<bool>::with_capacity(0),
))));
}
let pattern_vector = &columns[1]
.cast(&ConcreteDataType::string_datatype())
.context(InvalidInputTypeSnafu {
err_msg: "cannot cast `pattern` to string",
})?;
// Safety: both length and type are checked before
let pattern = pattern_vector.get(0).as_string().unwrap();
let pattern = match patterns.data_type() {
DataType::Utf8View => patterns.as_string_view().value(0),
DataType::Utf8 => patterns.as_string::<i32>().value(0),
DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
t => {
return Err(DataFusionError::Execution(format!(
"unsupported datatype {t}"
)));
}
};
self.eval(data_column, pattern)
}
}
impl MatchesFunction {
fn eval(&self, data: &VectorRef, pattern: String) -> Result<VectorRef> {
fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
let col_name = "data";
let parser_context = ParserContext::default();
let raw_ast = parser_context.parse_pattern(&pattern)?;
let raw_ast = parser_context.parse_pattern(pattern)?;
let ast = raw_ast.transform_ast()?;
let like_expr = ast.into_like_expr(col_name);
@@ -107,19 +104,14 @@ impl MatchesFunction {
let physical_expr =
planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
let data_array = data.to_arrow_array();
let arrow_schema = Arc::new(input_schema.as_arrow().clone());
let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
let num_rows = input_record_batch.num_rows();
let result = physical_expr.evaluate(&input_record_batch)?;
let result_array = result.into_array(num_rows)?;
let result_vector =
BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
data_type: DataType::Boolean,
})?;
Ok(Arc::new(result_vector))
Ok(ColumnarValue::Array(Arc::new(result_array)))
}
fn input_schema() -> DFSchema {
@@ -833,7 +825,9 @@ impl Tokenizer {
#[cfg(test)]
mod test {
use datatypes::vectors::StringVector;
use datafusion::arrow::array::StringArray;
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -1300,7 +1294,7 @@ mod test {
"The quick brown fox jumps over dog",
"The quick brown fox jumps over the dog",
];
let input_vector: VectorRef = Arc::new(StringVector::from(input_data));
let col: ArrayRef = Arc::new(StringArray::from(input_data));
let cases = [
// basic cases
("quick", vec![true, false, true, true, true, true, true]),
@@ -1391,9 +1385,22 @@ mod test {
let f = MatchesFunction;
for (pattern, expected) in cases {
let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap();
let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _;
assert_eq!(expected, actual, "{pattern}");
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(col.clone()),
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))),
],
arg_fields: vec![],
number_rows: col.len(),
return_field: Arc::new(Field::new("x", col.data_type().clone(), true)),
config_options: Arc::new(ConfigOptions::new()),
};
let actual = f
.invoke_with_args(args)
.and_then(|x| x.to_array(col.len()))
.unwrap();
let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
}
}
}

View File

@@ -23,10 +23,9 @@ use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::DataFusionError;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::VectorRef;
pub use rate::RateFunction;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
use crate::scalars::math::modulo::ModuloFunction;
@@ -75,11 +74,4 @@ impl Function for RangeFunction {
fn signature(&self) -> Signature {
Signature::variadic_any(Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
Err(DataFusionError::Internal(
"range_fn just a empty function used in range select, It should not be eval!".into(),
)
.into())
}
}

View File

@@ -15,54 +15,21 @@
use std::fmt::{self, Display};
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
use common_query::error::Result;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray};
use datafusion::arrow::datatypes::DataType as ArrowDataType;
use datafusion::logical_expr::Volatility;
use datafusion_expr::Signature;
use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion_common::{DataFusionError, ScalarValue, utils};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datatypes::data_type::DataType;
use datatypes::prelude::VectorRef;
use datatypes::types::LogicalPrimitiveType;
use datatypes::value::TryAsPrimitive;
use datatypes::vectors::PrimitiveVector;
use datatypes::with_match_primitive_type_id;
use snafu::{OptionExt, ensure};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use crate::function::{Function, FunctionContext};
use crate::function::Function;
#[derive(Clone, Debug, Default)]
pub struct ClampFunction;
const CLAMP_NAME: &str = "clamp";
/// Ensure the vector is constant and not empty (i.e., all values are identical)
fn ensure_constant_vector(vector: &VectorRef) -> Result<()> {
ensure!(
!vector.is_empty(),
InvalidFuncArgsSnafu {
err_msg: "Expect at least one value",
}
);
if vector.is_const() {
return Ok(());
}
let first = vector.get_ref(0);
for i in 1..vector.len() {
let v = vector.get_ref(i);
if first != v {
return InvalidFuncArgsSnafu {
err_msg: "All values in min/max argument must be identical",
}
.fail();
}
}
Ok(())
}
impl Function for ClampFunction {
fn name(&self) -> &str {
CLAMP_NAME
@@ -78,76 +45,12 @@ impl Function for ClampFunction {
Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 3, have: {}",
columns.len()
),
}
);
ensure!(
columns[0].data_type().is_numeric(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The first arg's type is not numeric, have: {}",
columns[0].data_type()
),
}
);
ensure!(
columns[0].data_type() == columns[1].data_type()
&& columns[1].data_type() == columns[2].data_type(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Arguments don't have identical types: {}, {}, {}",
columns[0].data_type(),
columns[1].data_type(),
columns[2].data_type()
),
}
);
ensure_constant_vector(&columns[1])?;
ensure_constant_vector(&columns[2])?;
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
let input_array = columns[0].to_arrow_array();
let input = input_array
.as_any()
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
.unwrap();
let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The second arg should not be none",
}
})?;
let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The third arg should not be none",
}
})?;
// ensure min <= max
ensure!(
min <= max,
InvalidFuncArgsSnafu {
err_msg: format!(
"The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
columns[1], columns[2]
),
}
);
clamp_impl::<$S, true, true>(input, min, max)
},{
unreachable!()
})
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [col, min, max] = utils::take_function_args(self.name(), args.args)?;
clamp_impl(col, min, max)
}
}
@@ -157,25 +60,155 @@ impl Display for ClampFunction {
}
}
fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
input: &PrimitiveArray<T::ArrowPrimitive>,
min: T::Native,
max: T::Native,
) -> Result<VectorRef> {
let iter = ArrayIter::new(input);
let result = iter.map(|x| {
x.map(|x| {
if CLAMP_MIN && x < min {
min
} else if CLAMP_MAX && x > max {
max
} else {
x
fn clamp_impl(
col: ColumnarValue,
min: ColumnarValue,
max: ColumnarValue,
) -> datafusion_common::Result<ColumnarValue> {
if col.data_type() != min.data_type() || min.data_type() != max.data_type() {
return Err(DataFusionError::Execution(format!(
"argument data types mismatch: {}, {}, {}",
col.data_type(),
min.data_type(),
max.data_type(),
)));
}
macro_rules! with_match_numerics_types {
($data_type:expr, | $_:tt $T:ident | $body:tt) => {{
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}
})
});
let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
Ok(Arc::new(PrimitiveVector::<T>::from(result)))
use datafusion::arrow::datatypes::{
Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
UInt16Type, UInt32Type, UInt64Type,
};
match $data_type {
ArrowDataType::Int8 => Ok(__with_ty__! { Int8Type }),
ArrowDataType::Int16 => Ok(__with_ty__! { Int16Type }),
ArrowDataType::Int32 => Ok(__with_ty__! { Int32Type }),
ArrowDataType::Int64 => Ok(__with_ty__! { Int64Type }),
ArrowDataType::UInt8 => Ok(__with_ty__! { UInt8Type }),
ArrowDataType::UInt16 => Ok(__with_ty__! { UInt16Type }),
ArrowDataType::UInt32 => Ok(__with_ty__! { UInt32Type }),
ArrowDataType::UInt64 => Ok(__with_ty__! { UInt64Type }),
ArrowDataType::Float32 => Ok(__with_ty__! { Float32Type }),
ArrowDataType::Float64 => Ok(__with_ty__! { Float64Type }),
_ => Err(DataFusionError::Execution(format!(
"unsupported numeric data type: '{}'",
$data_type
))),
}
}};
}
macro_rules! clamp {
($v: ident, $min: ident, $max: ident) => {
if $v < $min {
$min
} else if $v > $max {
$max
} else {
$v
}
};
}
match (col, min, max) {
(ColumnarValue::Scalar(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
if min > max {
return Err(DataFusionError::Execution(format!(
"min '{}' > max '{}'",
min, max
)));
}
Ok(ColumnarValue::Scalar(clamp!(col, min, max)))
}
(ColumnarValue::Array(col), ColumnarValue::Array(min), ColumnarValue::Array(max)) => {
if col.len() != min.len() || col.len() != max.len() {
return Err(DataFusionError::Internal(
"arguments not of same length".to_string(),
));
}
let result = with_match_numerics_types!(
col.data_type(),
|$S| {
let col = col.as_primitive::<$S>();
let min = min.as_primitive::<$S>();
let max = max.as_primitive::<$S>();
Arc::new(PrimitiveArray::<$S>::from(
(0..col.len())
.map(|i| {
let v = col.is_valid(i).then(|| col.value(i));
// Index safety: checked above, all have same length.
let min = min.is_valid(i).then(|| min.value(i));
let max = max.is_valid(i).then(|| max.value(i));
Ok(match (v, min, max) {
(Some(v), Some(min), Some(max)) => {
if min > max {
return Err(DataFusionError::Execution(format!(
"min '{}' > max '{}'",
min, max
)));
}
Some(clamp!(v, min, max))
},
_ => None,
})
})
.collect::<datafusion_common::Result<Vec<_>>>()?,
)
) as ArrayRef
}
)?;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Array(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
if min.is_null() || max.is_null() {
return Err(DataFusionError::Execution(
"argument 'min' or 'max' is null".to_string(),
));
}
let min = min.to_array()?;
let max = max.to_array()?;
let result = with_match_numerics_types!(
col.data_type(),
|$S| {
let col = col.as_primitive::<$S>();
// Index safety: checked above, both are not nulls.
let min = min.as_primitive::<$S>().value(0);
let max = max.as_primitive::<$S>().value(0);
if min > max {
return Err(DataFusionError::Execution(format!(
"min '{}' > max '{}'",
min, max
)));
}
Arc::new(PrimitiveArray::<$S>::from(
(0..col.len())
.map(|x| {
col.is_valid(x).then(|| {
let v = col.value(x);
clamp!(v, min, max)
})
})
.collect::<Vec<_>>(),
)
) as ArrayRef
}
)?;
Ok(ColumnarValue::Array(result))
}
_ => Err(DataFusionError::Internal(
"argument column types mismatch".to_string(),
)),
}
}
#[derive(Clone, Debug, Default)]
@@ -197,59 +230,19 @@ impl Function for ClampMinFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
columns.len()
),
}
);
ensure!(
columns[0].data_type().is_numeric(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The first arg's type is not numeric, have: {}",
columns[0].data_type()
),
}
);
ensure!(
columns[0].data_type() == columns[1].data_type(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Arguments don't have identical types: {}, {}",
columns[0].data_type(),
columns[1].data_type()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [col, min] = utils::take_function_args(self.name(), args.args)?;
ensure_constant_vector(&columns[1])?;
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
let input_array = columns[0].to_arrow_array();
let input = input_array
.as_any()
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
.unwrap();
let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The second arg (min) should not be none",
}
})?;
// For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic.
// We pass a default/dummy value for max.
let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
clamp_impl::<$S, true, false>(input, min, max_dummy)
},{
unreachable!()
})
let Some(max) = ScalarValue::max(&min.data_type()) else {
return Err(DataFusionError::Internal(format!(
"cannot find a max value for numeric data type {}",
min.data_type()
)));
};
clamp_impl(col, min, ColumnarValue::Scalar(max))
}
}
@@ -278,59 +271,19 @@ impl Function for ClampMaxFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
columns.len()
),
}
);
ensure!(
columns[0].data_type().is_numeric(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The first arg's type is not numeric, have: {}",
columns[0].data_type()
),
}
);
ensure!(
columns[0].data_type() == columns[1].data_type(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Arguments don't have identical types: {}, {}",
columns[0].data_type(),
columns[1].data_type()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [col, max] = utils::take_function_args(self.name(), args.args)?;
ensure_constant_vector(&columns[1])?;
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
let input_array = columns[0].to_arrow_array();
let input = input_array
.as_any()
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
.unwrap();
let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The second arg (max) should not be none",
}
})?;
// For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic.
// We pass a default/dummy value for min.
let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
clamp_impl::<$S, false, true>(input, min_dummy, max)
},{
unreachable!()
})
let Some(min) = ScalarValue::min(&max.data_type()) else {
return Err(DataFusionError::Internal(format!(
"cannot find a min value for numeric data type {}",
max.data_type()
)));
};
clamp_impl(col, ColumnarValue::Scalar(min), max)
}
}
@@ -345,55 +298,80 @@ mod test {
use std::sync::Arc;
use datatypes::prelude::ScalarVector;
use datatypes::vectors::{
ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
};
use arrow_schema::Field;
use datafusion_common::config::ConfigOptions;
use datatypes::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
use datatypes::arrow_array::StringArray;
use super::*;
use crate::function::FunctionContext;
macro_rules! impl_test_eval {
($func: ty) => {
impl $func {
fn test_eval(
&self,
args: Vec<ColumnarValue>,
number_rows: usize,
) -> datafusion_common::Result<ArrayRef> {
let input_type = args[0].data_type();
self.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows,
return_field: Arc::new(Field::new("x", input_type, false)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]).map_err(Into::into))
.map(|mut a| a.remove(0))
}
}
};
}
impl_test_eval!(ClampFunction);
impl_test_eval!(ClampMinFunction);
impl_test_eval!(ClampMaxFunction);
#[test]
fn clamp_i64() {
let inputs = [
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
-1,
10,
-1i64,
10i64,
vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
),
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
0,
0,
0i64,
0i64,
vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
),
(
vec![Some(-3), None, Some(-1), None, None, Some(2)],
-2,
1,
-2i64,
1i64,
vec![Some(-2), None, Some(-1), None, None, Some(1)],
),
(
vec![None, None, None, None, None],
0,
1,
0i64,
1i64,
vec![None, None, None, None, None],
),
];
let func = ClampFunction;
for (in_data, min, max, expected) in inputs {
let args = [
Arc::new(Int64Vector::from(in_data)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -402,42 +380,41 @@ mod test {
let inputs = [
(
vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
1,
3,
1u64,
3u64,
vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
),
(
vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
0,
0,
0u64,
0u64,
vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
),
(
vec![Some(0), None, Some(2), None, None, Some(5)],
1,
3,
1u64,
3u64,
vec![Some(1), None, Some(2), None, None, Some(3)],
),
(
vec![None, None, None, None, None],
0,
1,
0u64,
1u64,
vec![None, None, None, None, None],
),
];
let func = ClampFunction;
for (in_data, min, max, expected) in inputs {
let args = [
Arc::new(UInt64Vector::from(in_data)) as _,
Arc::new(UInt64Vector::from_vec(vec![min])) as _,
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(UInt64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -472,38 +449,18 @@ mod test {
let func = ClampFunction;
for (in_data, min, max, expected) in inputs {
let args = [
Arc::new(Float64Vector::from(in_data)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
#[test]
fn clamp_const_i32() {
let input = vec![Some(5)];
let min = 2;
let max = 4;
let func = ClampFunction;
let args = [
Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
assert_eq!(expected, result);
}
#[test]
fn clamp_invalid_min_max() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
@@ -511,28 +468,30 @@ mod test {
let max = -1.0;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
#[test]
fn clamp_type_not_match() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
let min = -1;
let max = 10;
let min = -1i64;
let max = 10u64;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -543,12 +502,13 @@ mod test {
let max = 1.0;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Float64Vector::from_vec(vec![min, max])) as _,
Arc::new(Float64Vector::from_vec(vec![max, min])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Array(Arc::new(Float64Array::from(vec![min, max]))),
ColumnarValue::Array(Arc::new(Float64Array::from(vec![max, min]))),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -558,11 +518,12 @@ mod test {
let min = -10.0;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -571,12 +532,13 @@ mod test {
let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];
let func = ClampFunction;
let args = [
Arc::new(StringVector::from(input)) as _,
Arc::new(StringVector::from_vec(vec!["bar"])) as _,
Arc::new(StringVector::from_vec(vec!["baz"])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(StringArray::from(input))),
ColumnarValue::Scalar("bar".into()),
ColumnarValue::Scalar("baz".into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -585,27 +547,26 @@ mod test {
let inputs = [
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
-1,
-1i64,
vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
),
(
vec![Some(-3), None, Some(-1), None, None, Some(2)],
-2,
-2i64,
vec![Some(-2), None, Some(-1), None, None, Some(2)],
),
];
let func = ClampMinFunction;
for (in_data, min, expected) in inputs {
let args = [
Arc::new(Int64Vector::from(in_data)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -614,27 +575,26 @@ mod test {
let inputs = [
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
1,
1i64,
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
),
(
vec![Some(-3), None, Some(-1), None, None, Some(2)],
0,
0i64,
vec![Some(-3), None, Some(-1), None, None, Some(0)],
),
];
let func = ClampMaxFunction;
for (in_data, max, expected) in inputs {
let args = [
Arc::new(Int64Vector::from(in_data)) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -648,15 +608,14 @@ mod test {
let func = ClampMinFunction;
for (in_data, min, expected) in inputs {
let args = [
Arc::new(Float64Vector::from(in_data)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -670,43 +629,44 @@ mod test {
let func = ClampMaxFunction;
for (in_data, max, expected) in inputs {
let args = [
Arc::new(Float64Vector::from(in_data)) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
#[test]
fn clamp_min_type_not_match() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
let min = -1;
let min = -1i64;
let func = ClampMinFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
#[test]
fn clamp_max_type_not_match() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
let max = 1;
let max = 1i64;
let func = ClampMaxFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(max.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
}

View File

@@ -28,7 +28,14 @@ mod vector_norm;
mod vector_sub;
mod vector_subvector;
use std::borrow::Cow;
use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use crate::function_registry::FunctionRegistry;
use crate::scalars::vector::impl_conv::as_veclit;
pub(crate) struct VectorFunction;
impl VectorFunction {
@@ -59,3 +66,155 @@ impl VectorFunction {
registry.register_scalar(elem_product::ElemProductFunction);
}
}
// Use macro instead of function to "return" the reference to `ScalarValue` in the
// `ColumnarValue::Array` match arm.
macro_rules! try_get_scalar_value {
($col: ident, $i: ident) => {
match $col {
datafusion::logical_expr::ColumnarValue::Array(a) => {
&datafusion_common::ScalarValue::try_from_array(a.as_ref(), $i)?
}
datafusion::logical_expr::ColumnarValue::Scalar(v) => v,
}
};
}
pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
if values.is_empty() {
return Ok(0);
}
let mut array_len = None;
for v in values {
array_len = match (v, array_len) {
(ColumnarValue::Array(a), None) => Some(a.len()),
(ColumnarValue::Array(a), Some(array_len)) => {
if array_len == a.len() {
Some(array_len)
} else {
return Err(DataFusionError::Internal(format!(
"Arguments has mixed length. Expected length: {array_len}, found length: {}",
a.len()
)));
}
}
(ColumnarValue::Scalar(_), array_len) => array_len,
}
}
// If array_len is none, it means there are only scalars, treat them each as 1 element array.
let array_len = array_len.unwrap_or(1);
Ok(array_len)
}
struct VectorCalculator<'a, F> {
name: &'a str,
func: F,
}
impl<F> VectorCalculator<'_, F>
where
F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
let result = (self.func)(v0, v1)?;
return Ok(ColumnarValue::Scalar(result));
}
let len = ensure_same_length(&[arg0, arg1])?;
let mut results = Vec::with_capacity(len);
for i in 0..len {
let v0 = try_get_scalar_value!(arg0, i);
let v1 = try_get_scalar_value!(arg1, i);
results.push((self.func)(v0, v1)?);
}
let results = ScalarValue::iter_to_array(results.into_iter())?;
Ok(ColumnarValue::Array(results))
}
}
impl<F> VectorCalculator<'_, F>
where
F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
{
fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
let v0 = as_veclit(v0)?;
let v1 = as_veclit(v1)?;
let result = (self.func)(&v0, &v1)?;
return Ok(ColumnarValue::Scalar(result));
}
let len = ensure_same_length(&[arg0, arg1])?;
let mut results = Vec::with_capacity(len);
match (arg0, arg1) {
(ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
let v0 = as_veclit(v0)?;
for i in 0..len {
let v1 = ScalarValue::try_from_array(a1, i)?;
let v1 = as_veclit(&v1)?;
results.push((self.func)(&v0, &v1)?);
}
}
(ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
let v1 = as_veclit(v1)?;
for i in 0..len {
let v0 = ScalarValue::try_from_array(a0, i)?;
let v0 = as_veclit(&v0)?;
results.push((self.func)(&v0, &v1)?);
}
}
(ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
for i in 0..len {
let v0 = ScalarValue::try_from_array(a0, i)?;
let v0 = as_veclit(&v0)?;
let v1 = ScalarValue::try_from_array(a1, i)?;
let v1 = as_veclit(&v1)?;
results.push((self.func)(&v0, &v1)?);
}
}
(ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
// unreachable because this arm has been separately dealt with above
unreachable!()
}
}
let results = ScalarValue::iter_to_array(results.into_iter())?;
Ok(ColumnarValue::Array(results))
}
}
impl<F> VectorCalculator<'_, F>
where
F: Fn(&ScalarValue) -> Result<ScalarValue>,
{
fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [arg0] = utils::take_function_args(self.name, &args.args)?;
let arg0 = match arg0 {
ColumnarValue::Scalar(v) => {
let result = (self.func)(v)?;
return Ok(ColumnarValue::Scalar(result));
}
ColumnarValue::Array(a) => a,
};
let len = arg0.len();
let mut results = Vec::with_capacity(len);
for i in 0..len {
let v = ScalarValue::try_from_array(arg0, i)?;
results.push((self.func)(&v)?);
}
let results = ScalarValue::iter_to_array(results.into_iter())?;
Ok(ColumnarValue::Array(results))
}
}

View File

@@ -17,7 +17,7 @@ use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::type_coercion::aggregates::BINARYS;
use datafusion_expr::{Signature, Volatility};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::types::vector_type_value_to_string;
use datatypes::value::Value;
@@ -41,7 +41,13 @@ impl Function for VectorToStringFunction {
}
fn signature(&self) -> Signature {
Signature::uniform(1, BINARYS.to_vec(), Volatility::Immutable)
Signature::one_of(
vec![
TypeSignature::Uniform(1, vec![DataType::BinaryView]),
TypeSignature::Uniform(1, BINARYS.to_vec()),
],
Volatility::Immutable,
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {

View File

@@ -19,20 +19,17 @@ mod l2sq;
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
macro_rules! define_distance_function {
($StructName:ident, $display_name:expr, $similarity_method:path) => {
/// A function calculates the distance between two vectors.
#[derive(Debug, Clone, Default)]
@@ -54,59 +51,34 @@ macro_rules! define_distance_function {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
if v0.len() != v1.len() {
return Err(datafusion_common::DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
let size = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(size);
if size == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..size {
let vec0 = match arg0_const.as_ref() {
Some(a) => Some(Cow::Borrowed(a.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let vec1 = match arg1_const.as_ref() {
Some(b) => Some(Cow::Borrowed(b.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
ensure!(
vec0.len() == vec1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match to calculate distance, have: {} vs {}",
vec0.len(),
vec1.len()
),
}
);
// Checked if the length of the vectors match
let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
result.push(Some(d));
let d = $similarity_method(v0, v1);
Some(d)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::Float32(result))
};
return Ok(result.to_vector());
let calculator = $crate::scalars::vector::VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -115,7 +87,7 @@ macro_rules! define_distance_function {
write!(f, "{}", $display_name.to_ascii_uppercase())
}
}
}
};
}
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
@@ -126,10 +98,29 @@ define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
mod tests {
use std::sync::Arc;
use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
let number_rows = args[0].len();
let args = ScalarFunctionArgs {
args: args
.iter()
.map(|x| ColumnarValue::Array(x.clone()))
.collect::<Vec<_>>(),
arg_fields: vec![],
number_rows,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
func.invoke_with_args(args)
.and_then(|x| x.to_array(number_rows))
}
#[test]
fn test_distance_string_string() {
let funcs = [
@@ -139,36 +130,34 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec![
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(StringVector::from(vec![
]));
let vec2: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[0.0, 1.0]"),
Some("[0.0, 1.0]"),
Some("[0.0, 1.0]"),
None,
])) as VectorRef;
]));
let result = func
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
let result = func
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}
@@ -181,37 +170,35 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(BinaryVector::from(vec![
let vec1: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
None,
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
]));
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
]));
let result = func
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
let result = func
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}
@@ -224,115 +211,35 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec![
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
]));
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
]));
let result = func
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
let result = func
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
}
}
#[test]
fn test_distance_const_string() {
let funcs = [
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
Box::new(DotProductFunction {}) as Box<dyn Function>,
];
for func in funcs {
let const_str = Arc::new(ConstantVector::new(
Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
4,
));
let vec1 = Arc::new(StringVector::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
let result = func
.eval(
&FunctionContext::default(),
&[const_str.clone(), vec1.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(!result.get(3).is_null());
let result = func
.eval(
&FunctionContext::default(),
&[vec1.clone(), const_str.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(!result.get(3).is_null());
let result = func
.eval(
&FunctionContext::default(),
&[const_str.clone(), vec2.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(!result.get(2).is_null());
assert!(result.get(3).is_null());
let result = func
.eval(
&FunctionContext::default(),
&[vec2.clone(), const_str.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(!result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}
@@ -345,15 +252,16 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"]));
let vec2: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"]));
let result = test_invoke(func.as_ref(), &[vec1, vec2]);
assert!(result.is_err());
let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
let vec2 =
Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
let vec1: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]]));
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![
0, 0, 128, 63, 0, 0, 0, 64,
]]));
let result = test_invoke(func.as_ref(), &[vec1, vec2]);
assert!(result.is_err());
}
}

View File

@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::{VectorCalculator, impl_conv};
const NAME: &str = "vec_elem_product";
@@ -64,43 +62,21 @@ impl Function for ElemProductFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = impl_conv::as_veclit(v0)?
.map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
Ok(ScalarValue::Float32(v0))
};
let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -114,27 +90,39 @@ impl Display for ElemProductFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
use crate::function::FunctionContext;
#[test]
fn test_elem_product() {
let func = ElemProductFunction;
let input0 = Arc::new(StringVector::from(vec![
let input = Arc::new(StringArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = func
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input.clone())],
arg_fields: vec![],
number_rows: input.len(),
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let result = result.as_primitive::<Float32Type>();
let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(120.0));
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
assert_eq!(result.value(0), 6.0);
assert_eq!(result.value(1), 120.0);
assert!(result.is_null(2));
}
}

View File

@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::{VectorCalculator, impl_conv};
const NAME: &str = "vec_elem_sum";
@@ -51,43 +49,21 @@ impl Function for ElemSumFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 =
impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
Ok(ScalarValue::Float32(v0))
};
let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum()));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -101,27 +77,40 @@ impl Display for ElemSumFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::StringVector;
use arrow::array::StringViewArray;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
use crate::function::FunctionContext;
#[test]
fn test_elem_sum() {
let func = ElemSumFunction;
let input0 = Arc::new(StringVector::from(vec![
let input = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = func
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input.clone())],
arg_fields: vec![],
number_rows: input.len(),
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let result = result.as_primitive::<Float32Type>();
let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(15.0));
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
assert_eq!(result.value(0), 6.0);
assert_eq!(result.value(1), 15.0);
assert!(result.is_null(2));
}
}

View File

@@ -13,40 +13,18 @@
// limitations under the License.
use std::borrow::Cow;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::ValueRef;
use datatypes::vectors::Vector;
/// Convert a constant string or binary literal to a vector literal.
pub fn as_veclit_if_const(arg: &Arc<dyn Vector>) -> Result<Option<Cow<'_, [f32]>>> {
if !arg.is_const() {
return Ok(None);
}
if arg.data_type() != ConcreteDataType::string_datatype()
&& arg.data_type() != ConcreteDataType::binary_datatype()
{
return Ok(None);
}
as_veclit(arg.get_ref(0))
}
use datafusion_common::ScalarValue;
/// Convert a string or binary literal to a vector literal.
pub fn as_veclit(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
match arg.data_type() {
ConcreteDataType::Binary(_) => arg
.as_binary()
.unwrap() // Safe: checked if it is a binary
.map(binlit_as_veclit)
pub fn as_veclit(arg: &ScalarValue) -> Result<Option<Cow<'_, [f32]>>> {
match arg {
ScalarValue::Binary(b) => b.as_ref().map(|x| binlit_as_veclit(x)).transpose(),
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) => s
.as_ref()
.map(|x| parse_veclit_from_strlit(x).map(Cow::Owned))
.transpose(),
ConcreteDataType::String(_) => arg
.as_string()
.unwrap() // Safe: checked if it is a string
.map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?)))
.transpose(),
ConcreteDataType::Null(_) => Ok(None),
_ => InvalidFuncArgsSnafu {
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
}

View File

@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_scalar_add";
@@ -60,7 +59,7 @@ impl Function for ScalarAddFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -70,52 +69,26 @@ impl Function for ScalarAddFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = arg0.get(i).as_f64_lossy();
let Some(arg0) = arg0 else {
result.push_null();
continue;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let ScalarValue::Float64(Some(v0)) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
let v1 = as_veclit(v1)?
.map(|v1| DVectorView::from_slice(&v1, v1.len()).add_scalar(*v0 as f32));
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
Ok(ScalarValue::BinaryView(result))
};
let vec = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec.add_scalar(arg0 as _);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}
@@ -129,7 +102,9 @@ impl Display for ScalarAddFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::{Float32Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -137,34 +112,42 @@ mod tests {
fn test_scalar_add() {
let func = ScalarAddFunction;
let input0 = Arc::new(Float32Vector::from(vec![
let input0 = Arc::new(Float64Array::from(vec![
Some(1.0),
Some(-1.0),
None,
Some(3.0),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice())
result.value(1),
veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}

View File

@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_scalar_mul";
@@ -60,7 +59,7 @@ impl Function for ScalarMulFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -70,52 +69,26 @@ impl Function for ScalarMulFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = arg0.get(i).as_f64_lossy();
let Some(arg0) = arg0 else {
result.push_null();
continue;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let ScalarValue::Float64(Some(v0)) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
let v1 =
as_veclit(v1)?.map(|v1| DVectorView::from_slice(&v1, v1.len()).scale(*v0 as f32));
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
Ok(ScalarValue::BinaryView(result))
};
let vec = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec.scale(arg0 as _);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}
@@ -129,7 +102,9 @@ impl Display for ScalarMulFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::{Float32Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -137,34 +112,42 @@ mod tests {
fn test_scalar_mul() {
let func = ScalarMulFunction;
let input0 = Arc::new(Float32Vector::from(vec![
let input0 = Arc::new(Float64Array::from(vec![
Some(2.0),
Some(-0.5),
None,
Some(3.0),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice())
result.value(1),
veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}

View File

@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_add";
@@ -51,7 +51,7 @@ impl Function for VectorAddFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -61,66 +61,36 @@ impl Function for VectorAddFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
arg0.len(),
arg1.len(),
)
}
);
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
let result = veclit_to_binlit((v0 + v1).as_slice());
Some(result)
} else {
None
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
result.push_null();
continue;
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
Ok(ScalarValue::BinaryView(result))
};
let vec_res = vec0 + vec1;
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -134,8 +104,9 @@ impl Display for VectorAddFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -143,63 +114,71 @@ mod tests {
fn test_sub() {
let func = VectorAddFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice())
result.value(1),
veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
#[test]
fn test_sub_error() {
let func = VectorAddFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
));
}
}

View File

@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
use snafu::ensure;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::as_veclit;
const NAME: &str = "vec_dim";
@@ -63,43 +62,20 @@ impl Function for VectorDimFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
Ok(ScalarValue::UInt64(v))
};
let len = arg0.len();
let mut result = UInt64VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(arg0.len() as u64));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -113,8 +89,10 @@ impl Display for VectorDimFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion::arrow::datatypes::UInt64Type;
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -122,49 +100,60 @@ mod tests {
fn test_vec_dim() {
let func = VectorDimFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[0.0,2.0,3.0]".to_string()),
Some("[1.0,2.0,3.0,4.0]".to_string()),
None,
Some("[5.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_primitive::<UInt64Type>();
assert_eq!(result.len(), 4);
assert_eq!(result.get_ref(0).as_u64().unwrap(), Some(3));
assert_eq!(result.get_ref(1).as_u64().unwrap(), Some(4));
assert!(result.get_ref(2).is_null());
assert_eq!(result.get_ref(3).as_u64().unwrap(), Some(1));
assert_eq!(result.value(0), 3);
assert_eq!(result.value(1), 4);
assert!(result.is_null(2));
assert_eq!(result.value(3), 1);
}
#[test]
fn test_dim_error() {
let func = VectorDimFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The length of the args is not correct, expect exactly one, have: 2"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: vec_dim function requires 1 argument, got 2")
)
}
}

View File

@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_div";
@@ -52,7 +52,7 @@ impl Function for VectorDivFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -62,64 +62,36 @@ impl Function for VectorDivFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match for division, have: {} vs {}",
arg0.len(),
arg1.len()
),
}
);
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec0.component_div(&vec1);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
let result = veclit_to_binlit((v0.component_div(&v1)).as_slice());
Some(result)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::BinaryView(result))
};
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -133,8 +105,9 @@ impl Display for VectorDivFunction {
mod tests {
use std::sync::Arc;
use common_query::error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -144,69 +117,80 @@ mod tests {
let vec0 = vec![1.0, 2.0, 3.0];
let vec1 = vec![1.0, 1.0];
let (len0, len1) = (vec0.len(), vec1.len());
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert_eq!(
e.to_string(),
"Execution error: vectors length not match: vec_div"
);
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!(
"The length of the vectors must match for division, have: {} vs {}",
len0, len1
)
)
}
_ => unreachable!(),
}
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[2.0,2.0,2.0]".to_string()),
None,
Some("[3.0,3.0,3.0]".to_string()),
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
result.value(0),
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice())
result.value(1),
veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())]));
let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));
let input0 = Arc::new(StringViewArray::from(vec![Some("[1.0,-2.0]".to_string())]));
let input1 = Arc::new(StringViewArray::from(vec![Some("[0.0,0.0]".to_string())]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(2))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice())
result.value(0),
veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice()
);
}
}

View File

@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::as_veclit;
const NAME: &str = "vec_kth_elem";
@@ -63,72 +62,44 @@ impl Function for VectorKthElemFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = as_veclit(v0)?;
let arg0 = &columns[0];
let arg1 = &columns[1];
let v1 = match v1 {
ScalarValue::Int64(None) => return Ok(ScalarValue::Float32(None)),
ScalarValue::Int64(Some(v1)) if *v1 >= 0 => *v1 as usize,
_ => {
return Err(DataFusionError::Execution(format!(
"2nd argument not a valid index or expected datatype: {}",
self.name()
)));
}
};
let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
let result = v0
.map(|v0| {
if v1 >= v0.len() {
Err(DataFusionError::Execution(format!(
"index out of bound: {}",
self.name()
)))
} else {
Ok(v0[v1])
}
})
.transpose()?;
Ok(ScalarValue::Float32(result))
};
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
let arg1 = arg1.get(i).as_f64_lossy();
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
ensure!(
arg1 >= 0.0 && arg1.fract() == 0.0,
InvalidFuncArgsSnafu {
err_msg: format!(
"Invalid argument: k must be a non-negative integer, but got k = {}.",
arg1
),
}
);
let k = arg1 as usize;
ensure!(
k < arg0.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Out of range: k must be in the range [0, {}], but got k = {}.",
arg0.len() - 1,
k
),
}
);
let value = arg0[k];
result.push(Some(value));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}
@@ -142,8 +113,10 @@ impl Display for VectorKthElemFunction {
mod tests {
use std::sync::Arc;
use common_query::error;
use datatypes::vectors::{Int64Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringViewArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -151,55 +124,66 @@ mod tests {
fn test_vec_kth_elem() {
let func = VectorKthElemFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(2), None, Some(1)]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_primitive::<Float32Type>();
assert_eq!(result.len(), 4);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert_eq!(result.value(0), 1.0);
assert_eq!(result.value(1), 6.0);
assert!(result.is_null(2));
assert!(result.is_null(3));
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
let input1 = Arc::new(Int64Vector::from(vec![Some(3)]));
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!("Out of range: k must be in the range [0, 2], but got k = 3.")
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: index out of bound: vec_kth_elem")
);
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(-1)]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!("Invalid argument: k must be a non-negative integer, but got k = -1.")
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Execution error: 2nd argument not a valid index or expected datatype: vec_kth_elem"
));
}
}

View File

@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_mul";
@@ -52,7 +52,7 @@ impl Function for VectorMulFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -62,64 +62,36 @@ impl Function for VectorMulFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match for multiplying, have: {} vs {}",
arg0.len(),
arg1.len()
),
}
);
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec1.component_mul(&vec0);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
Some(result)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::BinaryView(result))
};
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -133,8 +105,9 @@ impl Display for VectorMulFunction {
mod tests {
use std::sync::Arc;
use common_query::error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -144,56 +117,59 @@ mod tests {
let vec0 = vec![1.0, 2.0, 3.0];
let vec1 = vec![1.0, 1.0];
let (len0, len1) = (vec0.len(), vec1.len());
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: vectors length not match: vec_mul")
);
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!(
"The length of the vectors must match for multiplying, have: {} vs {}",
len0, len1
)
)
}
_ => unreachable!(),
}
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[2.0,2.0,2.0]".to_string()),
None,
Some("[3.0,3.0,3.0]".to_string()),
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
result.value(0),
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice())
result.value(1),
veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}

View File

@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::function::Function;
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_norm";
@@ -53,7 +52,7 @@ impl Function for VectorNormFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -66,55 +65,27 @@ impl Function for VectorNormFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = as_veclit(v0)?;
let Some(v0) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg0, arg0.len());
let vec2scalar = vec1.component_mul(&vec0);
let scalar_var = vec2scalar.sum().sqrt();
let v0 = DVectorView::from_slice(&v0, v0.len());
let result =
veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
Ok(ScalarValue::BinaryView(Some(result)))
};
let vec = DVectorView::from_slice(&arg0, arg0.len());
// Use unscale to avoid division by zero and keep more precision as possible
let vec_res = vec.unscale(scalar_var);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -128,7 +99,9 @@ impl Display for VectorNormFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -136,7 +109,7 @@ mod tests {
fn test_vec_norm() {
let func = VectorNormFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[0.0,2.0,3.0]".to_string()),
Some("[1.0,2.0,3.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
@@ -144,26 +117,36 @@ mod tests {
None,
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0)],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.invoke_with_args(args)
.and_then(|x| x.to_array(5))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 5);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice())
result.value(0),
veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice())
result.value(1),
veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice()
);
assert_eq!(
result.get_ref(2).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice())
result.value(2),
veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice()
);
assert_eq!(
result.get_ref(3).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice())
result.value(3),
veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice()
);
assert!(result.get_ref(4).is_null());
assert!(result.is_null(4));
}
}

View File

@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_sub";
@@ -51,7 +51,7 @@ impl Function for VectorSubFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -61,66 +61,36 @@ impl Function for VectorSubFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
arg0.len(),
arg1.len(),
)
}
);
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
let result = veclit_to_binlit((v0 - v1).as_slice());
Some(result)
} else {
None
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
result.push_null();
continue;
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
Ok(ScalarValue::BinaryView(result))
};
let vec_res = vec0 - vec1;
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -134,8 +104,9 @@ impl Display for VectorSubFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -143,63 +114,71 @@ mod tests {
fn test_sub() {
let func = VectorSubFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice())
result.value(0),
veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice())
result.value(1),
veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
#[test]
fn test_sub_error() {
let func = VectorSubFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
));
}
}

View File

@@ -12,18 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
use datafusion::arrow::datatypes::Int64Type;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{ScalarValue, utils};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::function::Function;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_subvector";
@@ -52,7 +54,7 @@ impl Function for VectorSubvectorFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -65,50 +67,28 @@ impl Function for VectorSubvectorFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly three, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
let arg2 = &columns[2];
ensure!(
arg0.len() == arg1.len() && arg1.len() == arg2.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
arg0.len(),
arg1.len(),
arg2.len()
)
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?;
let arg1 = arg1.as_primitive::<Int64Type>();
let arg2 = arg2.as_primitive::<Int64Type>();
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
let mut builder = BinaryViewBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = arg1.get(i).as_i64();
let arg2 = arg2.get(i).as_i64();
let v = ScalarValue::try_from_array(&arg0, i)?;
let arg0 = as_veclit(&v)?;
let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
result.push_null();
builder.append_null();
continue;
};
@@ -126,10 +106,10 @@ impl Function for VectorSubvectorFunction {
let subvector = &arg0[arg1 as usize..arg2 as usize];
let binlit = veclit_to_binlit(subvector);
result.push(Some(&binlit));
builder.append_value(&binlit);
}
Ok(result.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -143,89 +123,102 @@ impl Display for VectorSubvectorFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::{Int64Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
use crate::function::FunctionContext;
#[test]
fn test_subvector() {
let func = VectorSubvectorFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
None,
Some("[11.0, 12.0, 13.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)]));
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)]));
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1, input2])
.invoke_with_args(args)
.and_then(|x| x.to_array(5))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
);
assert!(result.get_ref(2).is_null());
assert_eq!(
result.get_ref(3).as_binary().unwrap(),
Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
result.value(1),
veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
);
assert!(result.is_null(2));
assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
}
#[test]
fn test_subvector_error() {
let func = VectorSubvectorFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0]".to_string()),
Some("[4.0, 5.0, 6.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)]));
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
));
}
#[test]
fn test_subvector_invalid_indices() {
let func = VectorSubvectorFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0]".to_string()),
Some("[4.0, 5.0, 6.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));
let input1 = Arc::new(Int64Array::from(vec![Some(1), Some(3)]));
let input2 = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"Invalid start and end indices: start=3, end=4, vec_len=3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with("External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"));
}
}

View File

@@ -12,11 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use common_function::scalars::vector::impl_conv::{
as_veclit, as_veclit_if_const, veclit_to_binlit,
};
use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
use datafusion_common::ScalarValue;
use datatypes::prelude::Value;
use nalgebra::{Const, DVectorView, Dyn, OVector};
@@ -35,14 +32,10 @@ async fn test_vec_product_aggregator() -> Result<(), common_query::error::Error>
let sql = "SELECT vector FROM vectors";
let vectors = exec_selection(engine, sql).await;
let column = vectors[0].column(0);
let vector_const = as_veclit_if_const(column)?;
let column = vectors[0].column(0).to_arrow_array();
for i in 0..column.len() {
let vector = match vector_const.as_ref() {
Some(vector) => Some(Cow::Borrowed(vector.as_ref())),
None => as_veclit(column.get_ref(i))?,
};
let v = ScalarValue::try_from_array(&column, i)?;
let vector = as_veclit(&v)?;
let Some(vector) = vector else {
expected_value = None;
break;

View File

@@ -12,12 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::ops::AddAssign;
use common_function::scalars::vector::impl_conv::{
as_veclit, as_veclit_if_const, veclit_to_binlit,
};
use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
use datafusion_common::ScalarValue;
use datatypes::prelude::Value;
use nalgebra::{Const, DVectorView, Dyn, OVector};
@@ -36,14 +34,10 @@ async fn test_vec_sum_aggregator() -> Result<(), common_query::error::Error> {
let sql = "SELECT vector FROM vectors";
let vectors = exec_selection(engine, sql).await;
let column = vectors[0].column(0);
let vector_const = as_veclit_if_const(column)?;
let column = vectors[0].column(0).to_arrow_array();
for i in 0..column.len() {
let vector = match vector_const.as_ref() {
Some(vector) => Some(Cow::Borrowed(vector.as_ref())),
None => as_veclit(column.get_ref(i))?,
};
let v = ScalarValue::try_from_array(&column, i)?;
let vector = as_veclit(&v)?;
let Some(vector) = vector else {
expected_value = None;
break;

View File

@@ -76,13 +76,7 @@ SELECT CLAMP(0.5, 0, 1);
SELECT CLAMP(10, 1, 0);
Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray<Int64>
[
1,
] }, PrimitiveVector { array: PrimitiveArray<Int64>
[
0,
] }
Error: 3001(EngineExecuteQuery), Execution error: min '1' > max '0'
SELECT CLAMP_MIN(10, 12);

View File

@@ -240,11 +240,11 @@ SELECT geohash(37.76938, -122.3889, 11);
SELECT geohash(37.76938, -122.3889, 100);
Error: 3001(EngineExecuteQuery), Invalid geohash resolution 100, expect value: [1, 12]
Error: 3001(EngineExecuteQuery), Execution error: Invalid geohash resolution 100, valid value range: [1, 12]
SELECT geohash(37.76938, -122.3889, -1);
Error: 3001(EngineExecuteQuery), Invalid geohash resolution -1, expect value: [1, 12]
Error: 3001(EngineExecuteQuery), Cast error: Can't cast value -1 to type UInt8
SELECT geohash(37.76938, -122.3889, 11::Int8);

View File

@@ -375,17 +375,7 @@ TQL EVAL (0, 15, '5s') clamp(host, 6 - 6, 6 + 6);
-- SQLNESS SORT_RESULT 3 1
TQL EVAL (0, 15, '5s') clamp(host, 12, 0);
Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray<Float64>
[
12.0,
0.0,
0.0,
0.0,
12.0,
12.0,
[
] }, PrimitiveVector { array: PrimitiveArray<Float64>
] }
Error: 3001(EngineExecuteQuery), Execution error: min '12' > max '0'
-- SQLNESS SORT_RESULT 3 1
TQL EVAL (0, 15, '5s') clamp(host{host="host1"}, -1, 6);

View File

@@ -99,7 +99,7 @@ SELECT round(vec_cos_distance(v, v), 2) FROM t;
-- Unexpected dimension --
SELECT vec_cos_distance(v, '[1.0]') FROM t;
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_cos_distance
-- Invalid type --
SELECT vec_cos_distance(v, 1.0) FROM t;
@@ -174,7 +174,7 @@ SELECT round(vec_l2sq_distance(v, v), 2) FROM t;
-- Unexpected dimension --
SELECT vec_l2sq_distance(v, '[1.0]') FROM t;
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_l2sq_distance
-- Invalid type --
SELECT vec_l2sq_distance(v, 1.0) FROM t;
@@ -249,7 +249,7 @@ SELECT round(vec_dot_product(v, v), 2) FROM t;
-- Unexpected dimension --
SELECT vec_dot_product(v, '[1.0]') FROM t;
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_dot_product
-- Invalid type --
SELECT vec_dot_product(v, 1.0) FROM t;