mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-05 21:02:58 +00:00
refactor: rewrite some UDFs to DataFusion style (part 2) (#6967)
* refactor: rewrite some UDFs to DataFusion style (part 2) Signed-off-by: luofucong <luofc@foxmail.com> * deal with vector UDFs `(scalar, scalar)` situation, and try getting the scalar value reference everytime Signed-off-by: luofucong <luofc@foxmail.com> * reduce some vector literal parsing Signed-off-by: luofucong <luofc@foxmail.com> * fix ci Signed-off-by: luofucong <luofc@foxmail.com> --------- Signed-off-by: luofucong <luofc@foxmail.com>
This commit is contained in:
@@ -14,14 +14,15 @@
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use common_query::error::{ArrowComputeSnafu, Result};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::utils;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::compute::kernels::numeric;
|
||||
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
use snafu::{ResultExt, ensure};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
|
||||
/// A function that adds an interval value to a Timestamp or Date and returns the result.
|
||||
@@ -58,25 +59,15 @@ impl Function for DateAddFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect 2, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
let left = columns[0].to_arrow_array();
|
||||
let right = columns[1].to_arrow_array();
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [left, right] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
|
||||
let arrow_type = result.data_type().clone();
|
||||
Helper::try_into_vector(result).context(IntoVectorSnafu {
|
||||
data_type: arrow_type,
|
||||
})
|
||||
Ok(ColumnarValue::Array(result))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,12 +81,14 @@ impl fmt::Display for DateAddFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_expr::{TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::IntervalDayTime;
|
||||
use datatypes::value::Value;
|
||||
use datatypes::vectors::{
|
||||
DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{
|
||||
Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
|
||||
TimestampSecondArray,
|
||||
};
|
||||
use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use datafusion_expr::{TypeSignature, Volatility};
|
||||
|
||||
use super::{DateAddFunction, *};
|
||||
|
||||
@@ -142,25 +135,37 @@ mod tests {
|
||||
];
|
||||
let results = [Some(124), None, Some(45), None];
|
||||
|
||||
let time_vector = TimestampSecondVector::from(times.clone());
|
||||
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
|
||||
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
|
||||
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
|
||||
ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
|
||||
];
|
||||
|
||||
let vector = f
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new(
|
||||
"x",
|
||||
DataType::Timestamp(TimeUnit::Second, None),
|
||||
true,
|
||||
)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
|
||||
.map(|mut a| a.remove(0))
|
||||
.unwrap();
|
||||
let vector = vector.as_primitive::<TimestampSecondType>();
|
||||
|
||||
assert_eq!(4, vector.len());
|
||||
for (i, _t) in times.iter().enumerate() {
|
||||
let v = vector.get(i);
|
||||
let result = results.get(i).unwrap();
|
||||
|
||||
if result.is_none() {
|
||||
assert_eq!(Value::Null, v);
|
||||
continue;
|
||||
}
|
||||
match v {
|
||||
Value::Timestamp(ts) => {
|
||||
assert_eq!(ts.value(), result.unwrap());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
if let Some(x) = result {
|
||||
assert!(vector.is_valid(i));
|
||||
assert_eq!(vector.value(i), *x);
|
||||
} else {
|
||||
assert!(vector.is_null(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -174,25 +179,37 @@ mod tests {
|
||||
let intervals = vec![1, 2, 3, 1];
|
||||
let results = [Some(154), None, Some(131), None];
|
||||
|
||||
let date_vector = DateVector::from(dates.clone());
|
||||
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
|
||||
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
|
||||
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
|
||||
ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
|
||||
];
|
||||
|
||||
let vector = f
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new(
|
||||
"x",
|
||||
DataType::Timestamp(TimeUnit::Second, None),
|
||||
true,
|
||||
)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
|
||||
.map(|mut a| a.remove(0))
|
||||
.unwrap();
|
||||
let vector = vector.as_primitive::<Date32Type>();
|
||||
|
||||
assert_eq!(4, vector.len());
|
||||
for (i, _t) in dates.iter().enumerate() {
|
||||
let v = vector.get(i);
|
||||
let result = results.get(i).unwrap();
|
||||
|
||||
if result.is_none() {
|
||||
assert_eq!(Value::Null, v);
|
||||
continue;
|
||||
}
|
||||
match v {
|
||||
Value::Date(date) => {
|
||||
assert_eq!(date.val(), result.unwrap());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
if let Some(x) = result {
|
||||
assert!(vector.is_valid(i));
|
||||
assert_eq!(vector.value(i), *x);
|
||||
} else {
|
||||
assert!(vector.is_null(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,14 +14,15 @@
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use common_query::error::{ArrowComputeSnafu, Result};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::utils;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::compute::kernels::numeric;
|
||||
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
use snafu::{ResultExt, ensure};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
|
||||
/// A function that subtracts an interval value from a Timestamp or Date and returns the result.
|
||||
@@ -58,25 +59,15 @@ impl Function for DateSubFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect 2, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
let left = columns[0].to_arrow_array();
|
||||
let right = columns[1].to_arrow_array();
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [left, right] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
|
||||
let arrow_type = result.data_type().clone();
|
||||
Helper::try_into_vector(result).context(IntoVectorSnafu {
|
||||
data_type: arrow_type,
|
||||
})
|
||||
Ok(ColumnarValue::Array(result))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,12 +81,14 @@ impl fmt::Display for DateSubFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_expr::{TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::IntervalDayTime;
|
||||
use datatypes::value::Value;
|
||||
use datatypes::vectors::{
|
||||
DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{
|
||||
Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
|
||||
TimestampSecondArray,
|
||||
};
|
||||
use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use datafusion_expr::{TypeSignature, Volatility};
|
||||
|
||||
use super::{DateSubFunction, *};
|
||||
|
||||
@@ -142,25 +135,37 @@ mod tests {
|
||||
];
|
||||
let results = [Some(122), None, Some(39), None];
|
||||
|
||||
let time_vector = TimestampSecondVector::from(times.clone());
|
||||
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
|
||||
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
|
||||
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
|
||||
ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
|
||||
];
|
||||
|
||||
let vector = f
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new(
|
||||
"x",
|
||||
DataType::Timestamp(TimeUnit::Second, None),
|
||||
true,
|
||||
)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
|
||||
.map(|mut a| a.remove(0))
|
||||
.unwrap();
|
||||
let vector = vector.as_primitive::<TimestampSecondType>();
|
||||
|
||||
assert_eq!(4, vector.len());
|
||||
for (i, _t) in times.iter().enumerate() {
|
||||
let v = vector.get(i);
|
||||
let result = results.get(i).unwrap();
|
||||
|
||||
if result.is_none() {
|
||||
assert_eq!(Value::Null, v);
|
||||
continue;
|
||||
}
|
||||
match v {
|
||||
Value::Timestamp(ts) => {
|
||||
assert_eq!(ts.value(), result.unwrap());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
if let Some(x) = result {
|
||||
assert!(vector.is_valid(i));
|
||||
assert_eq!(vector.value(i), *x);
|
||||
} else {
|
||||
assert!(vector.is_null(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -180,25 +185,37 @@ mod tests {
|
||||
let intervals = vec![1, 2, 3, 1];
|
||||
let results = [Some(3659), None, Some(1168), None];
|
||||
|
||||
let date_vector = DateVector::from(dates.clone());
|
||||
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
|
||||
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
|
||||
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
|
||||
ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
|
||||
];
|
||||
|
||||
let vector = f
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new(
|
||||
"x",
|
||||
DataType::Timestamp(TimeUnit::Second, None),
|
||||
true,
|
||||
)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
|
||||
.map(|mut a| a.remove(0))
|
||||
.unwrap();
|
||||
let vector = vector.as_primitive::<Date32Type>();
|
||||
|
||||
assert_eq!(4, vector.len());
|
||||
for (i, _t) in dates.iter().enumerate() {
|
||||
let v = vector.get(i);
|
||||
let result = results.get(i).unwrap();
|
||||
|
||||
if result.is_none() {
|
||||
assert_eq!(Value::Null, v);
|
||||
continue;
|
||||
}
|
||||
match v {
|
||||
Value::Date(date) => {
|
||||
assert_eq!(date.val(), result.unwrap());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
if let Some(x) = result {
|
||||
assert!(vector.is_valid(i));
|
||||
assert_eq!(vector.value(i), *x);
|
||||
} else {
|
||||
assert!(vector.is_null(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,62 +17,26 @@ use std::sync::Arc;
|
||||
|
||||
use common_error::ext::{BoxedError, PlainError};
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_query::error::{self, InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::datatypes::Field;
|
||||
use common_query::error::{self, Result};
|
||||
use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder};
|
||||
use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, utils};
|
||||
use datafusion_expr::type_coercion::aggregates::INTEGERS;
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::scalars::{Scalar, ScalarVectorBuilder};
|
||||
use datatypes::value::{ListValue, Value};
|
||||
use datatypes::vectors::{ListVectorBuilder, MutableVector, StringVectorBuilder, VectorRef};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use geohash::Coord;
|
||||
use snafu::{ResultExt, ensure};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::scalars::geo::helpers;
|
||||
|
||||
macro_rules! ensure_resolution_usize {
|
||||
($v: ident) => {
|
||||
if !($v > 0 && $v <= 12) {
|
||||
Err(BoxedError::new(PlainError::new(
|
||||
format!("Invalid geohash resolution {}, expect value: [1, 12]", $v),
|
||||
StatusCode::EngineExecuteQuery,
|
||||
)))
|
||||
.context(error::ExecuteSnafu)
|
||||
} else {
|
||||
Ok($v as usize)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn try_into_resolution(v: Value) -> Result<usize> {
|
||||
match v {
|
||||
Value::Int8(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
Value::Int16(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
Value::Int32(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
Value::Int64(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
Value::UInt8(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
Value::UInt16(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
Value::UInt32(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
Value::UInt64(v) => {
|
||||
ensure_resolution_usize!(v)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
fn ensure_resolution_usize(v: u8) -> datafusion_common::Result<usize> {
|
||||
if v == 0 || v > 12 {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"Invalid geohash resolution {v}, valid value range: [1, 12]"
|
||||
)));
|
||||
}
|
||||
Ok(v as usize)
|
||||
}
|
||||
|
||||
/// Function that returns the geohash string for a given geospatial coordinate.
|
||||
@@ -109,31 +73,33 @@ impl Function for GeohashFunction {
|
||||
Signature::one_of(signatures, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 3,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect 3, provided : {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let lat_vec = &columns[0];
|
||||
let lon_vec = &columns[1];
|
||||
let resolution_vec = &columns[2];
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
|
||||
let lon_vec = lon_vec.as_primitive::<Float64Type>();
|
||||
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
let size = lat_vec.len();
|
||||
let mut results = StringVectorBuilder::with_capacity(size);
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
|
||||
for i in 0..size {
|
||||
let lat = lat_vec.get(i).as_f64_lossy();
|
||||
let lon = lon_vec.get(i).as_f64_lossy();
|
||||
let r = try_into_resolution(resolution_vec.get(i))?;
|
||||
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
|
||||
let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
|
||||
let r = resolutions
|
||||
.is_valid(i)
|
||||
.then(|| ensure_resolution_usize(resolutions.value(i)))
|
||||
.transpose()?;
|
||||
|
||||
let result = match (lat, lon) {
|
||||
(Some(lat), Some(lon)) => {
|
||||
let result = match (lat, lon, r) {
|
||||
(Some(lat), Some(lon), Some(r)) => {
|
||||
let coord = Coord { x: lon, y: lat };
|
||||
let encoded = geohash::encode(coord, r)
|
||||
.map_err(|e| {
|
||||
@@ -148,10 +114,10 @@ impl Function for GeohashFunction {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
results.push(result.as_deref());
|
||||
builder.append_option(result);
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,8 +142,8 @@ impl Function for GeohashNeighboursFunction {
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::List(Arc::new(Field::new(
|
||||
"x",
|
||||
DataType::Utf8,
|
||||
"item",
|
||||
DataType::Utf8View,
|
||||
false,
|
||||
))))
|
||||
}
|
||||
@@ -199,32 +165,33 @@ impl Function for GeohashNeighboursFunction {
|
||||
Signature::one_of(signatures, Volatility::Stable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 3,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect 3, provided : {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let lat_vec = &columns[0];
|
||||
let lon_vec = &columns[1];
|
||||
let resolution_vec = &columns[2];
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
|
||||
let lon_vec = lon_vec.as_primitive::<Float64Type>();
|
||||
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
let size = lat_vec.len();
|
||||
let mut results =
|
||||
ListVectorBuilder::with_type_capacity(ConcreteDataType::string_datatype(), size);
|
||||
let mut builder = ListBuilder::new(StringViewBuilder::new());
|
||||
|
||||
for i in 0..size {
|
||||
let lat = lat_vec.get(i).as_f64_lossy();
|
||||
let lon = lon_vec.get(i).as_f64_lossy();
|
||||
let r = try_into_resolution(resolution_vec.get(i))?;
|
||||
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
|
||||
let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
|
||||
let r = resolutions
|
||||
.is_valid(i)
|
||||
.then(|| ensure_resolution_usize(resolutions.value(i)))
|
||||
.transpose()?;
|
||||
|
||||
let result = match (lat, lon) {
|
||||
(Some(lat), Some(lon)) => {
|
||||
match (lat, lon, r) {
|
||||
(Some(lat), Some(lon), Some(r)) => {
|
||||
let coord = Coord { x: lon, y: lat };
|
||||
let encoded = geohash::encode(coord, r)
|
||||
.map_err(|e| {
|
||||
@@ -242,8 +209,8 @@ impl Function for GeohashNeighboursFunction {
|
||||
))
|
||||
})
|
||||
.context(error::ExecuteSnafu)?;
|
||||
Some(ListValue::new(
|
||||
vec![
|
||||
builder.append_value(
|
||||
[
|
||||
neighbours.n,
|
||||
neighbours.nw,
|
||||
neighbours.w,
|
||||
@@ -254,22 +221,14 @@ impl Function for GeohashNeighboursFunction {
|
||||
neighbours.ne,
|
||||
]
|
||||
.into_iter()
|
||||
.map(Value::from)
|
||||
.collect(),
|
||||
ConcreteDataType::string_datatype(),
|
||||
))
|
||||
.map(Some),
|
||||
);
|
||||
}
|
||||
_ => None,
|
||||
_ => builder.append_null(),
|
||||
};
|
||||
|
||||
if let Some(list_value) = result {
|
||||
results.push(Some(list_value.as_scalar_ref()));
|
||||
} else {
|
||||
results.push(None);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,13 +19,11 @@ use common_error::ext::{BoxedError, PlainError};
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_query::error::{self, Result};
|
||||
use datafusion::arrow::array::{
|
||||
Array, ArrayRef, AsArray, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder,
|
||||
StringViewArray, StringViewBuilder, UInt8Builder, UInt64Builder,
|
||||
Array, AsArray, BooleanBuilder, Float64Builder, Int32Builder, ListBuilder, StringViewArray,
|
||||
StringViewBuilder, UInt8Builder, UInt64Builder,
|
||||
};
|
||||
use datafusion::arrow::compute;
|
||||
use datafusion::arrow::datatypes::{
|
||||
ArrowPrimitiveType, Float64Type, Int64Type, UInt8Type, UInt64Type,
|
||||
};
|
||||
use datafusion::arrow::datatypes::{Float64Type, Int64Type, UInt8Type, UInt64Type};
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue, utils};
|
||||
use datafusion_expr::type_coercion::aggregates::INTEGERS;
|
||||
@@ -36,6 +34,7 @@ use h3o::{CellIndex, LatLng, Resolution};
|
||||
use snafu::prelude::*;
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::scalars::geo::helpers;
|
||||
|
||||
static CELL_TYPES: LazyLock<Vec<DataType>> =
|
||||
LazyLock::new(|| vec![DataType::Int64, DataType::UInt64, DataType::Utf8]);
|
||||
@@ -89,11 +88,11 @@ impl Function for H3LatLngToCell {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let lat_vec = cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
let lon_vec = cast::<Float64Type>(&lon_vec)?;
|
||||
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
|
||||
let lon_vec = lon_vec.as_primitive::<Float64Type>();
|
||||
let resolutions = cast::<UInt8Type>(&resolution_vec)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&resolution_vec)?;
|
||||
let resolution_vec = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
let size = lat_vec.len();
|
||||
@@ -171,11 +170,11 @@ impl Function for H3LatLngToCellString {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [lat_vec, lon_vec, resolution_vec] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let lat_vec = cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
|
||||
let lat_vec = lat_vec.as_primitive::<Float64Type>();
|
||||
let lon_vec = cast::<Float64Type>(&lon_vec)?;
|
||||
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
|
||||
let lon_vec = lon_vec.as_primitive::<Float64Type>();
|
||||
let resolutions = cast::<UInt8Type>(&resolution_vec)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&resolution_vec)?;
|
||||
let resolution_vec = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
let size = lat_vec.len();
|
||||
@@ -547,7 +546,7 @@ impl Function for H3CellToChildren {
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cell_vec, res_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let resolutions = cast::<UInt8Type>(&res_vec)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&res_vec)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
let size = cell_vec.len();
|
||||
@@ -641,7 +640,7 @@ where
|
||||
{
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [cells, resolutions] = utils::take_function_args(name, args)?;
|
||||
let resolutions = cast::<UInt8Type>(&resolutions)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
let mut builder = UInt64Builder::with_capacity(cells.len());
|
||||
@@ -698,7 +697,7 @@ impl Function for H3ChildPosToCell {
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [pos_vec, cell_vec, res_vec] = utils::take_function_args(self.name(), args)?;
|
||||
let resolutions = cast::<UInt8Type>(&res_vec)?;
|
||||
let resolutions = helpers::cast::<UInt8Type>(&res_vec)?;
|
||||
let resolutions = resolutions.as_primitive::<UInt8Type>();
|
||||
|
||||
let size = cell_vec.len();
|
||||
@@ -722,18 +721,6 @@ impl Function for H3ChildPosToCell {
|
||||
}
|
||||
}
|
||||
|
||||
fn cast<T: ArrowPrimitiveType>(array: &ArrayRef) -> datafusion_common::Result<ArrayRef> {
|
||||
let x = compute::cast_with_options(
|
||||
array.as_ref(),
|
||||
&T::DATA_TYPE,
|
||||
&compute::CastOptions {
|
||||
safe: false,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
/// Function that returns cells with k distances of given cell
|
||||
#[derive(Clone, Debug, Default, Display)]
|
||||
#[display("{}", self.name())]
|
||||
|
||||
@@ -12,6 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType};
|
||||
use datafusion::arrow::compute;
|
||||
|
||||
macro_rules! ensure_columns_len {
|
||||
($columns:ident) => {
|
||||
snafu::ensure!(
|
||||
@@ -73,3 +76,15 @@ macro_rules! ensure_and_coerce {
|
||||
}
|
||||
|
||||
pub(crate) use ensure_and_coerce;
|
||||
|
||||
pub(crate) fn cast<T: ArrowPrimitiveType>(array: &ArrayRef) -> datafusion_common::Result<ArrayRef> {
|
||||
let x = compute::cast_with_options(
|
||||
array.as_ref(),
|
||||
&T::DATA_TYPE,
|
||||
&compute::CastOptions {
|
||||
safe: false,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
@@ -16,21 +16,20 @@ use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result};
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
|
||||
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
|
||||
use datafusion::common::{DFSchema, Result as DfResult};
|
||||
use datafusion::execution::SessionStateBuilder;
|
||||
use datafusion::logical_expr::{self, Expr, Volatility};
|
||||
use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
|
||||
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
|
||||
use datafusion_expr::Signature;
|
||||
use datafusion_common::{DataFusionError, utils};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::array::RecordBatch;
|
||||
use datatypes::arrow::datatypes::{DataType, Field};
|
||||
use datatypes::prelude::VectorRef;
|
||||
use datatypes::vectors::BooleanVector;
|
||||
use snafu::{OptionExt, ResultExt, ensure};
|
||||
use store_api::storage::ConcreteDataType;
|
||||
use snafu::{OptionExt, ensure};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
/// `matches` for full text search.
|
||||
@@ -65,38 +64,36 @@ impl Function for MatchesFunction {
|
||||
}
|
||||
|
||||
// TODO: read case-sensitive config
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly 2, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [data_column, patterns] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let data_column = &columns[0];
|
||||
if data_column.is_empty() {
|
||||
return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
|
||||
return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
|
||||
Vec::<bool>::with_capacity(0),
|
||||
))));
|
||||
}
|
||||
|
||||
let pattern_vector = &columns[1]
|
||||
.cast(&ConcreteDataType::string_datatype())
|
||||
.context(InvalidInputTypeSnafu {
|
||||
err_msg: "cannot cast `pattern` to string",
|
||||
})?;
|
||||
// Safety: both length and type are checked before
|
||||
let pattern = pattern_vector.get(0).as_string().unwrap();
|
||||
let pattern = match patterns.data_type() {
|
||||
DataType::Utf8View => patterns.as_string_view().value(0),
|
||||
DataType::Utf8 => patterns.as_string::<i32>().value(0),
|
||||
DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
|
||||
t => {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"unsupported datatype {t}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
self.eval(data_column, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
impl MatchesFunction {
|
||||
fn eval(&self, data: &VectorRef, pattern: String) -> Result<VectorRef> {
|
||||
fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
|
||||
let col_name = "data";
|
||||
let parser_context = ParserContext::default();
|
||||
let raw_ast = parser_context.parse_pattern(&pattern)?;
|
||||
let raw_ast = parser_context.parse_pattern(pattern)?;
|
||||
let ast = raw_ast.transform_ast()?;
|
||||
|
||||
let like_expr = ast.into_like_expr(col_name);
|
||||
@@ -107,19 +104,14 @@ impl MatchesFunction {
|
||||
let physical_expr =
|
||||
planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
|
||||
|
||||
let data_array = data.to_arrow_array();
|
||||
let arrow_schema = Arc::new(input_schema.as_arrow().clone());
|
||||
let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
|
||||
|
||||
let num_rows = input_record_batch.num_rows();
|
||||
let result = physical_expr.evaluate(&input_record_batch)?;
|
||||
let result_array = result.into_array(num_rows)?;
|
||||
let result_vector =
|
||||
BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
|
||||
data_type: DataType::Boolean,
|
||||
})?;
|
||||
|
||||
Ok(Arc::new(result_vector))
|
||||
Ok(ColumnarValue::Array(Arc::new(result_array)))
|
||||
}
|
||||
|
||||
fn input_schema() -> DFSchema {
|
||||
@@ -833,7 +825,9 @@ impl Tokenizer {
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use datatypes::vectors::StringVector;
|
||||
use datafusion::arrow::array::StringArray;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1300,7 +1294,7 @@ mod test {
|
||||
"The quick brown fox jumps over dog",
|
||||
"The quick brown fox jumps over the dog",
|
||||
];
|
||||
let input_vector: VectorRef = Arc::new(StringVector::from(input_data));
|
||||
let col: ArrayRef = Arc::new(StringArray::from(input_data));
|
||||
let cases = [
|
||||
// basic cases
|
||||
("quick", vec![true, false, true, true, true, true, true]),
|
||||
@@ -1391,9 +1385,22 @@ mod test {
|
||||
|
||||
let f = MatchesFunction;
|
||||
for (pattern, expected) in cases {
|
||||
let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap();
|
||||
let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _;
|
||||
assert_eq!(expected, actual, "{pattern}");
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(col.clone()),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: col.len(),
|
||||
return_field: Arc::new(Field::new("x", col.data_type().clone(), true)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let actual = f
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(col.len()))
|
||||
.unwrap();
|
||||
let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
|
||||
assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,9 @@ use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::error::DataFusionError;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datatypes::vectors::VectorRef;
|
||||
pub use rate::RateFunction;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
use crate::scalars::math::modulo::ModuloFunction;
|
||||
|
||||
@@ -75,11 +74,4 @@ impl Function for RangeFunction {
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::variadic_any(Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
Err(DataFusionError::Internal(
|
||||
"range_fn just a empty function used in range select, It should not be eval!".into(),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,54 +15,21 @@
|
||||
use std::fmt::{self, Display};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray};
|
||||
use datafusion::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datafusion::logical_expr::Volatility;
|
||||
use datafusion_expr::Signature;
|
||||
use datafusion::logical_expr::{ColumnarValue, Volatility};
|
||||
use datafusion_common::{DataFusionError, ScalarValue, utils};
|
||||
use datafusion_expr::type_coercion::aggregates::NUMERICS;
|
||||
use datatypes::data_type::DataType;
|
||||
use datatypes::prelude::VectorRef;
|
||||
use datatypes::types::LogicalPrimitiveType;
|
||||
use datatypes::value::TryAsPrimitive;
|
||||
use datatypes::vectors::PrimitiveVector;
|
||||
use datatypes::with_match_primitive_type_id;
|
||||
use snafu::{OptionExt, ensure};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ClampFunction;
|
||||
|
||||
const CLAMP_NAME: &str = "clamp";
|
||||
|
||||
/// Ensure the vector is constant and not empty (i.e., all values are identical)
|
||||
fn ensure_constant_vector(vector: &VectorRef) -> Result<()> {
|
||||
ensure!(
|
||||
!vector.is_empty(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: "Expect at least one value",
|
||||
}
|
||||
);
|
||||
|
||||
if vector.is_const() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let first = vector.get_ref(0);
|
||||
for i in 1..vector.len() {
|
||||
let v = vector.get_ref(i);
|
||||
if first != v {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: "All values in min/max argument must be identical",
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Function for ClampFunction {
|
||||
fn name(&self) -> &str {
|
||||
CLAMP_NAME
|
||||
@@ -78,76 +45,12 @@ impl Function for ClampFunction {
|
||||
Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 3,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly 3, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
ensure!(
|
||||
columns[0].data_type().is_numeric(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The first arg's type is not numeric, have: {}",
|
||||
columns[0].data_type()
|
||||
),
|
||||
}
|
||||
);
|
||||
ensure!(
|
||||
columns[0].data_type() == columns[1].data_type()
|
||||
&& columns[1].data_type() == columns[2].data_type(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Arguments don't have identical types: {}, {}, {}",
|
||||
columns[0].data_type(),
|
||||
columns[1].data_type(),
|
||||
columns[2].data_type()
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
ensure_constant_vector(&columns[1])?;
|
||||
ensure_constant_vector(&columns[2])?;
|
||||
|
||||
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
|
||||
let input_array = columns[0].to_arrow_array();
|
||||
let input = input_array
|
||||
.as_any()
|
||||
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
|
||||
.unwrap();
|
||||
|
||||
let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
|
||||
.with_context(|| {
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: "The second arg should not be none",
|
||||
}
|
||||
})?;
|
||||
let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
|
||||
.with_context(|| {
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: "The third arg should not be none",
|
||||
}
|
||||
})?;
|
||||
|
||||
// ensure min <= max
|
||||
ensure!(
|
||||
min <= max,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
|
||||
columns[1], columns[2]
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
clamp_impl::<$S, true, true>(input, min, max)
|
||||
},{
|
||||
unreachable!()
|
||||
})
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [col, min, max] = utils::take_function_args(self.name(), args.args)?;
|
||||
clamp_impl(col, min, max)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,25 +60,155 @@ impl Display for ClampFunction {
|
||||
}
|
||||
}
|
||||
|
||||
fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
|
||||
input: &PrimitiveArray<T::ArrowPrimitive>,
|
||||
min: T::Native,
|
||||
max: T::Native,
|
||||
) -> Result<VectorRef> {
|
||||
let iter = ArrayIter::new(input);
|
||||
let result = iter.map(|x| {
|
||||
x.map(|x| {
|
||||
if CLAMP_MIN && x < min {
|
||||
min
|
||||
} else if CLAMP_MAX && x > max {
|
||||
max
|
||||
} else {
|
||||
x
|
||||
fn clamp_impl(
|
||||
col: ColumnarValue,
|
||||
min: ColumnarValue,
|
||||
max: ColumnarValue,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if col.data_type() != min.data_type() || min.data_type() != max.data_type() {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"argument data types mismatch: {}, {}, {}",
|
||||
col.data_type(),
|
||||
min.data_type(),
|
||||
max.data_type(),
|
||||
)));
|
||||
}
|
||||
|
||||
macro_rules! with_match_numerics_types {
|
||||
($data_type:expr, | $_:tt $T:ident | $body:tt) => {{
|
||||
macro_rules! __with_ty__ {
|
||||
( $_ $T:ident ) => {
|
||||
$body
|
||||
};
|
||||
}
|
||||
})
|
||||
});
|
||||
let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
|
||||
Ok(Arc::new(PrimitiveVector::<T>::from(result)))
|
||||
|
||||
use datafusion::arrow::datatypes::{
|
||||
Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
|
||||
UInt16Type, UInt32Type, UInt64Type,
|
||||
};
|
||||
|
||||
match $data_type {
|
||||
ArrowDataType::Int8 => Ok(__with_ty__! { Int8Type }),
|
||||
ArrowDataType::Int16 => Ok(__with_ty__! { Int16Type }),
|
||||
ArrowDataType::Int32 => Ok(__with_ty__! { Int32Type }),
|
||||
ArrowDataType::Int64 => Ok(__with_ty__! { Int64Type }),
|
||||
ArrowDataType::UInt8 => Ok(__with_ty__! { UInt8Type }),
|
||||
ArrowDataType::UInt16 => Ok(__with_ty__! { UInt16Type }),
|
||||
ArrowDataType::UInt32 => Ok(__with_ty__! { UInt32Type }),
|
||||
ArrowDataType::UInt64 => Ok(__with_ty__! { UInt64Type }),
|
||||
ArrowDataType::Float32 => Ok(__with_ty__! { Float32Type }),
|
||||
ArrowDataType::Float64 => Ok(__with_ty__! { Float64Type }),
|
||||
_ => Err(DataFusionError::Execution(format!(
|
||||
"unsupported numeric data type: '{}'",
|
||||
$data_type
|
||||
))),
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! clamp {
|
||||
($v: ident, $min: ident, $max: ident) => {
|
||||
if $v < $min {
|
||||
$min
|
||||
} else if $v > $max {
|
||||
$max
|
||||
} else {
|
||||
$v
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
match (col, min, max) {
|
||||
(ColumnarValue::Scalar(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
|
||||
if min > max {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"min '{}' > max '{}'",
|
||||
min, max
|
||||
)));
|
||||
}
|
||||
Ok(ColumnarValue::Scalar(clamp!(col, min, max)))
|
||||
}
|
||||
|
||||
(ColumnarValue::Array(col), ColumnarValue::Array(min), ColumnarValue::Array(max)) => {
|
||||
if col.len() != min.len() || col.len() != max.len() {
|
||||
return Err(DataFusionError::Internal(
|
||||
"arguments not of same length".to_string(),
|
||||
));
|
||||
}
|
||||
let result = with_match_numerics_types!(
|
||||
col.data_type(),
|
||||
|$S| {
|
||||
let col = col.as_primitive::<$S>();
|
||||
let min = min.as_primitive::<$S>();
|
||||
let max = max.as_primitive::<$S>();
|
||||
Arc::new(PrimitiveArray::<$S>::from(
|
||||
(0..col.len())
|
||||
.map(|i| {
|
||||
let v = col.is_valid(i).then(|| col.value(i));
|
||||
// Index safety: checked above, all have same length.
|
||||
let min = min.is_valid(i).then(|| min.value(i));
|
||||
let max = max.is_valid(i).then(|| max.value(i));
|
||||
Ok(match (v, min, max) {
|
||||
(Some(v), Some(min), Some(max)) => {
|
||||
if min > max {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"min '{}' > max '{}'",
|
||||
min, max
|
||||
)));
|
||||
}
|
||||
Some(clamp!(v, min, max))
|
||||
},
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.collect::<datafusion_common::Result<Vec<_>>>()?,
|
||||
)
|
||||
) as ArrayRef
|
||||
}
|
||||
)?;
|
||||
Ok(ColumnarValue::Array(result))
|
||||
}
|
||||
|
||||
(ColumnarValue::Array(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
|
||||
if min.is_null() || max.is_null() {
|
||||
return Err(DataFusionError::Execution(
|
||||
"argument 'min' or 'max' is null".to_string(),
|
||||
));
|
||||
}
|
||||
let min = min.to_array()?;
|
||||
let max = max.to_array()?;
|
||||
let result = with_match_numerics_types!(
|
||||
col.data_type(),
|
||||
|$S| {
|
||||
let col = col.as_primitive::<$S>();
|
||||
// Index safety: checked above, both are not nulls.
|
||||
let min = min.as_primitive::<$S>().value(0);
|
||||
let max = max.as_primitive::<$S>().value(0);
|
||||
if min > max {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"min '{}' > max '{}'",
|
||||
min, max
|
||||
)));
|
||||
}
|
||||
Arc::new(PrimitiveArray::<$S>::from(
|
||||
(0..col.len())
|
||||
.map(|x| {
|
||||
col.is_valid(x).then(|| {
|
||||
let v = col.value(x);
|
||||
clamp!(v, min, max)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
) as ArrayRef
|
||||
}
|
||||
)?;
|
||||
Ok(ColumnarValue::Array(result))
|
||||
}
|
||||
_ => Err(DataFusionError::Internal(
|
||||
"argument column types mismatch".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@@ -197,59 +230,19 @@ impl Function for ClampMinFunction {
|
||||
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly 2, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
ensure!(
|
||||
columns[0].data_type().is_numeric(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The first arg's type is not numeric, have: {}",
|
||||
columns[0].data_type()
|
||||
),
|
||||
}
|
||||
);
|
||||
ensure!(
|
||||
columns[0].data_type() == columns[1].data_type(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Arguments don't have identical types: {}, {}",
|
||||
columns[0].data_type(),
|
||||
columns[1].data_type()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [col, min] = utils::take_function_args(self.name(), args.args)?;
|
||||
|
||||
ensure_constant_vector(&columns[1])?;
|
||||
|
||||
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
|
||||
let input_array = columns[0].to_arrow_array();
|
||||
let input = input_array
|
||||
.as_any()
|
||||
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
|
||||
.unwrap();
|
||||
|
||||
let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
|
||||
.with_context(|| {
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: "The second arg (min) should not be none",
|
||||
}
|
||||
})?;
|
||||
// For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic.
|
||||
// We pass a default/dummy value for max.
|
||||
let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
|
||||
|
||||
clamp_impl::<$S, true, false>(input, min, max_dummy)
|
||||
},{
|
||||
unreachable!()
|
||||
})
|
||||
let Some(max) = ScalarValue::max(&min.data_type()) else {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"cannot find a max value for numeric data type {}",
|
||||
min.data_type()
|
||||
)));
|
||||
};
|
||||
clamp_impl(col, min, ColumnarValue::Scalar(max))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -278,59 +271,19 @@ impl Function for ClampMaxFunction {
|
||||
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly 2, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
ensure!(
|
||||
columns[0].data_type().is_numeric(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The first arg's type is not numeric, have: {}",
|
||||
columns[0].data_type()
|
||||
),
|
||||
}
|
||||
);
|
||||
ensure!(
|
||||
columns[0].data_type() == columns[1].data_type(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Arguments don't have identical types: {}, {}",
|
||||
columns[0].data_type(),
|
||||
columns[1].data_type()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let [col, max] = utils::take_function_args(self.name(), args.args)?;
|
||||
|
||||
ensure_constant_vector(&columns[1])?;
|
||||
|
||||
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
|
||||
let input_array = columns[0].to_arrow_array();
|
||||
let input = input_array
|
||||
.as_any()
|
||||
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
|
||||
.unwrap();
|
||||
|
||||
let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
|
||||
.with_context(|| {
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: "The second arg (max) should not be none",
|
||||
}
|
||||
})?;
|
||||
// For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic.
|
||||
// We pass a default/dummy value for min.
|
||||
let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
|
||||
|
||||
clamp_impl::<$S, false, true>(input, min_dummy, max)
|
||||
},{
|
||||
unreachable!()
|
||||
})
|
||||
let Some(min) = ScalarValue::min(&max.data_type()) else {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"cannot find a min value for numeric data type {}",
|
||||
max.data_type()
|
||||
)));
|
||||
};
|
||||
clamp_impl(col, ColumnarValue::Scalar(min), max)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,55 +298,80 @@ mod test {
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::prelude::ScalarVector;
|
||||
use datatypes::vectors::{
|
||||
ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
|
||||
};
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
use datatypes::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
|
||||
use datatypes::arrow_array::StringArray;
|
||||
|
||||
use super::*;
|
||||
use crate::function::FunctionContext;
|
||||
|
||||
macro_rules! impl_test_eval {
|
||||
($func: ty) => {
|
||||
impl $func {
|
||||
fn test_eval(
|
||||
&self,
|
||||
args: Vec<ColumnarValue>,
|
||||
number_rows: usize,
|
||||
) -> datafusion_common::Result<ArrayRef> {
|
||||
let input_type = args[0].data_type();
|
||||
self.invoke_with_args(ScalarFunctionArgs {
|
||||
args,
|
||||
arg_fields: vec![],
|
||||
number_rows,
|
||||
return_field: Arc::new(Field::new("x", input_type, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.and_then(|v| ColumnarValue::values_to_arrays(&[v]).map_err(Into::into))
|
||||
.map(|mut a| a.remove(0))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_test_eval!(ClampFunction);
|
||||
impl_test_eval!(ClampMinFunction);
|
||||
impl_test_eval!(ClampMaxFunction);
|
||||
|
||||
#[test]
|
||||
fn clamp_i64() {
|
||||
let inputs = [
|
||||
(
|
||||
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
|
||||
-1,
|
||||
10,
|
||||
-1i64,
|
||||
10i64,
|
||||
vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
|
||||
),
|
||||
(
|
||||
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
|
||||
0,
|
||||
0,
|
||||
0i64,
|
||||
0i64,
|
||||
vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
|
||||
),
|
||||
(
|
||||
vec![Some(-3), None, Some(-1), None, None, Some(2)],
|
||||
-2,
|
||||
1,
|
||||
-2i64,
|
||||
1i64,
|
||||
vec![Some(-2), None, Some(-1), None, None, Some(1)],
|
||||
),
|
||||
(
|
||||
vec![None, None, None, None, None],
|
||||
0,
|
||||
1,
|
||||
0i64,
|
||||
1i64,
|
||||
vec![None, None, None, None, None],
|
||||
),
|
||||
];
|
||||
|
||||
let func = ClampFunction;
|
||||
for (in_data, min, max, expected) in inputs {
|
||||
let args = [
|
||||
Arc::new(Int64Vector::from(in_data)) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![min])) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = in_data.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
|
||||
assert_eq!(expected, result);
|
||||
let result = func.test_eval(args, number_rows).unwrap();
|
||||
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
|
||||
assert_eq!(expected.as_ref(), result.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,42 +380,41 @@ mod test {
|
||||
let inputs = [
|
||||
(
|
||||
vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
|
||||
1,
|
||||
3,
|
||||
1u64,
|
||||
3u64,
|
||||
vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
|
||||
),
|
||||
(
|
||||
vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
|
||||
0,
|
||||
0,
|
||||
0u64,
|
||||
0u64,
|
||||
vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
|
||||
),
|
||||
(
|
||||
vec![Some(0), None, Some(2), None, None, Some(5)],
|
||||
1,
|
||||
3,
|
||||
1u64,
|
||||
3u64,
|
||||
vec![Some(1), None, Some(2), None, None, Some(3)],
|
||||
),
|
||||
(
|
||||
vec![None, None, None, None, None],
|
||||
0,
|
||||
1,
|
||||
0u64,
|
||||
1u64,
|
||||
vec![None, None, None, None, None],
|
||||
),
|
||||
];
|
||||
|
||||
let func = ClampFunction;
|
||||
for (in_data, min, max, expected) in inputs {
|
||||
let args = [
|
||||
Arc::new(UInt64Vector::from(in_data)) as _,
|
||||
Arc::new(UInt64Vector::from_vec(vec![min])) as _,
|
||||
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = in_data.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(UInt64Array::from(in_data))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
|
||||
assert_eq!(expected, result);
|
||||
let result = func.test_eval(args, number_rows).unwrap();
|
||||
let expected: ArrayRef = Arc::new(UInt64Array::from(expected));
|
||||
assert_eq!(expected.as_ref(), result.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -472,38 +449,18 @@ mod test {
|
||||
|
||||
let func = ClampFunction;
|
||||
for (in_data, min, max, expected) in inputs {
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(in_data)) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![min])) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = in_data.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
|
||||
assert_eq!(expected, result);
|
||||
let result = func.test_eval(args, number_rows).unwrap();
|
||||
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
|
||||
assert_eq!(expected.as_ref(), result.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_const_i32() {
|
||||
let input = vec![Some(5)];
|
||||
let min = 2;
|
||||
let max = 4;
|
||||
|
||||
let func = ClampFunction;
|
||||
let args = [
|
||||
Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![min])) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![max])) as _,
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
|
||||
assert_eq!(expected, result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_invalid_min_max() {
|
||||
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
|
||||
@@ -511,28 +468,30 @@ mod test {
|
||||
let max = -1.0;
|
||||
|
||||
let func = ClampFunction;
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(input)) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![min])) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = input.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func.eval(&FunctionContext::default(), args.as_slice());
|
||||
let result = func.test_eval(args, number_rows);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_type_not_match() {
|
||||
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
|
||||
let min = -1;
|
||||
let max = 10;
|
||||
let min = -1i64;
|
||||
let max = 10u64;
|
||||
|
||||
let func = ClampFunction;
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(input)) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![min])) as _,
|
||||
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = input.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func.eval(&FunctionContext::default(), args.as_slice());
|
||||
let result = func.test_eval(args, number_rows);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -543,12 +502,13 @@ mod test {
|
||||
let max = 1.0;
|
||||
|
||||
let func = ClampFunction;
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(input)) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![min, max])) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![max, min])) as _,
|
||||
let number_rows = input.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(vec![min, max]))),
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(vec![max, min]))),
|
||||
];
|
||||
let result = func.eval(&FunctionContext::default(), args.as_slice());
|
||||
let result = func.test_eval(args, number_rows);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -558,11 +518,12 @@ mod test {
|
||||
let min = -10.0;
|
||||
|
||||
let func = ClampFunction;
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(input)) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![min])) as _,
|
||||
let number_rows = input.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
];
|
||||
let result = func.eval(&FunctionContext::default(), args.as_slice());
|
||||
let result = func.test_eval(args, number_rows);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -571,12 +532,13 @@ mod test {
|
||||
let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];
|
||||
|
||||
let func = ClampFunction;
|
||||
let args = [
|
||||
Arc::new(StringVector::from(input)) as _,
|
||||
Arc::new(StringVector::from_vec(vec!["bar"])) as _,
|
||||
Arc::new(StringVector::from_vec(vec!["baz"])) as _,
|
||||
let number_rows = input.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(StringArray::from(input))),
|
||||
ColumnarValue::Scalar("bar".into()),
|
||||
ColumnarValue::Scalar("baz".into()),
|
||||
];
|
||||
let result = func.eval(&FunctionContext::default(), args.as_slice());
|
||||
let result = func.test_eval(args, number_rows);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -585,27 +547,26 @@ mod test {
|
||||
let inputs = [
|
||||
(
|
||||
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
|
||||
-1,
|
||||
-1i64,
|
||||
vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
|
||||
),
|
||||
(
|
||||
vec![Some(-3), None, Some(-1), None, None, Some(2)],
|
||||
-2,
|
||||
-2i64,
|
||||
vec![Some(-2), None, Some(-1), None, None, Some(2)],
|
||||
),
|
||||
];
|
||||
|
||||
let func = ClampMinFunction;
|
||||
for (in_data, min, expected) in inputs {
|
||||
let args = [
|
||||
Arc::new(Int64Vector::from(in_data)) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![min])) as _,
|
||||
let number_rows = in_data.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
|
||||
assert_eq!(expected, result);
|
||||
let result = func.test_eval(args, number_rows).unwrap();
|
||||
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
|
||||
assert_eq!(expected.as_ref(), result.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -614,27 +575,26 @@ mod test {
|
||||
let inputs = [
|
||||
(
|
||||
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
|
||||
1,
|
||||
1i64,
|
||||
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
|
||||
),
|
||||
(
|
||||
vec![Some(-3), None, Some(-1), None, None, Some(2)],
|
||||
0,
|
||||
0i64,
|
||||
vec![Some(-3), None, Some(-1), None, None, Some(0)],
|
||||
),
|
||||
];
|
||||
|
||||
let func = ClampMaxFunction;
|
||||
for (in_data, max, expected) in inputs {
|
||||
let args = [
|
||||
Arc::new(Int64Vector::from(in_data)) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = in_data.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
|
||||
assert_eq!(expected, result);
|
||||
let result = func.test_eval(args, number_rows).unwrap();
|
||||
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
|
||||
assert_eq!(expected.as_ref(), result.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -648,15 +608,14 @@ mod test {
|
||||
|
||||
let func = ClampMinFunction;
|
||||
for (in_data, min, expected) in inputs {
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(in_data)) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![min])) as _,
|
||||
let number_rows = in_data.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
|
||||
assert_eq!(expected, result);
|
||||
let result = func.test_eval(args, number_rows).unwrap();
|
||||
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
|
||||
assert_eq!(expected.as_ref(), result.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -670,43 +629,44 @@ mod test {
|
||||
|
||||
let func = ClampMaxFunction;
|
||||
for (in_data, max, expected) in inputs {
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(in_data)) as _,
|
||||
Arc::new(Float64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = in_data.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), args.as_slice())
|
||||
.unwrap();
|
||||
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
|
||||
assert_eq!(expected, result);
|
||||
let result = func.test_eval(args, number_rows).unwrap();
|
||||
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
|
||||
assert_eq!(expected.as_ref(), result.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_min_type_not_match() {
|
||||
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
|
||||
let min = -1;
|
||||
let min = -1i64;
|
||||
|
||||
let func = ClampMinFunction;
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(input)) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![min])) as _,
|
||||
let number_rows = input.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
|
||||
ColumnarValue::Scalar(min.into()),
|
||||
];
|
||||
let result = func.eval(&FunctionContext::default(), args.as_slice());
|
||||
let result = func.test_eval(args, number_rows);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_max_type_not_match() {
|
||||
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
|
||||
let max = 1;
|
||||
let max = 1i64;
|
||||
|
||||
let func = ClampMaxFunction;
|
||||
let args = [
|
||||
Arc::new(Float64Vector::from(input)) as _,
|
||||
Arc::new(Int64Vector::from_vec(vec![max])) as _,
|
||||
let number_rows = input.len();
|
||||
let args = vec![
|
||||
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
|
||||
ColumnarValue::Scalar(max.into()),
|
||||
];
|
||||
let result = func.eval(&FunctionContext::default(), args.as_slice());
|
||||
let result = func.test_eval(args, number_rows);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,14 @@ mod vector_norm;
|
||||
mod vector_sub;
|
||||
mod vector_subvector;
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
use crate::scalars::vector::impl_conv::as_veclit;
|
||||
|
||||
pub(crate) struct VectorFunction;
|
||||
|
||||
impl VectorFunction {
|
||||
@@ -59,3 +66,155 @@ impl VectorFunction {
|
||||
registry.register_scalar(elem_product::ElemProductFunction);
|
||||
}
|
||||
}
|
||||
|
||||
// This has to be a macro rather than a function: in the `ColumnarValue::Array` arm the
// `ScalarValue` is a temporary built at the expansion site, and the expression evaluates
// to a reference to it. The expansion contains a `?`, so it can only be used inside a
// function whose return type accepts the error of `ScalarValue::try_from_array`.
macro_rules! try_get_scalar_value {
    ($column: ident, $row: ident) => {
        match $column {
            // A scalar argument is the same value for every row.
            datafusion::logical_expr::ColumnarValue::Scalar(scalar) => scalar,
            // An array argument is materialized into a `ScalarValue` for row `$row`.
            datafusion::logical_expr::ColumnarValue::Array(array) => {
                &datafusion_common::ScalarValue::try_from_array(array.as_ref(), $row)?
            }
        }
    };
}
|
||||
|
||||
pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
|
||||
if values.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let mut array_len = None;
|
||||
for v in values {
|
||||
array_len = match (v, array_len) {
|
||||
(ColumnarValue::Array(a), None) => Some(a.len()),
|
||||
(ColumnarValue::Array(a), Some(array_len)) => {
|
||||
if array_len == a.len() {
|
||||
Some(array_len)
|
||||
} else {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"Arguments has mixed length. Expected length: {array_len}, found length: {}",
|
||||
a.len()
|
||||
)));
|
||||
}
|
||||
}
|
||||
(ColumnarValue::Scalar(_), array_len) => array_len,
|
||||
}
|
||||
}
|
||||
|
||||
// If array_len is none, it means there are only scalars, treat them each as 1 element array.
|
||||
let array_len = array_len.unwrap_or(1);
|
||||
Ok(array_len)
|
||||
}
|
||||
|
||||
/// Pairs a function name with the per-row computation `func`.
///
/// The `impl` blocks on this type pick an evaluation strategy based on the
/// signature of `F`.
struct VectorCalculator<'a, F> {
    /// Function name, forwarded to the argument-count check.
    name: &'a str,
    /// Computation applied to each row's argument value(s).
    func: F,
}
|
||||
|
||||
impl<F> VectorCalculator<'_, F>
|
||||
where
|
||||
F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
|
||||
{
|
||||
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
|
||||
let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
|
||||
|
||||
if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
|
||||
let result = (self.func)(v0, v1)?;
|
||||
return Ok(ColumnarValue::Scalar(result));
|
||||
}
|
||||
|
||||
let len = ensure_same_length(&[arg0, arg1])?;
|
||||
let mut results = Vec::with_capacity(len);
|
||||
for i in 0..len {
|
||||
let v0 = try_get_scalar_value!(arg0, i);
|
||||
let v1 = try_get_scalar_value!(arg1, i);
|
||||
results.push((self.func)(v0, v1)?);
|
||||
}
|
||||
|
||||
let results = ScalarValue::iter_to_array(results.into_iter())?;
|
||||
Ok(ColumnarValue::Array(results))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> VectorCalculator<'_, F>
|
||||
where
|
||||
F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
|
||||
{
|
||||
fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
|
||||
let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
|
||||
|
||||
if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
|
||||
let v0 = as_veclit(v0)?;
|
||||
let v1 = as_veclit(v1)?;
|
||||
let result = (self.func)(&v0, &v1)?;
|
||||
return Ok(ColumnarValue::Scalar(result));
|
||||
}
|
||||
|
||||
let len = ensure_same_length(&[arg0, arg1])?;
|
||||
let mut results = Vec::with_capacity(len);
|
||||
|
||||
match (arg0, arg1) {
|
||||
(ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
|
||||
let v0 = as_veclit(v0)?;
|
||||
for i in 0..len {
|
||||
let v1 = ScalarValue::try_from_array(a1, i)?;
|
||||
let v1 = as_veclit(&v1)?;
|
||||
results.push((self.func)(&v0, &v1)?);
|
||||
}
|
||||
}
|
||||
(ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
|
||||
let v1 = as_veclit(v1)?;
|
||||
for i in 0..len {
|
||||
let v0 = ScalarValue::try_from_array(a0, i)?;
|
||||
let v0 = as_veclit(&v0)?;
|
||||
results.push((self.func)(&v0, &v1)?);
|
||||
}
|
||||
}
|
||||
(ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
|
||||
for i in 0..len {
|
||||
let v0 = ScalarValue::try_from_array(a0, i)?;
|
||||
let v0 = as_veclit(&v0)?;
|
||||
let v1 = ScalarValue::try_from_array(a1, i)?;
|
||||
let v1 = as_veclit(&v1)?;
|
||||
results.push((self.func)(&v0, &v1)?);
|
||||
}
|
||||
}
|
||||
(ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
|
||||
// unreachable because this arm has been separately dealt with above
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
let results = ScalarValue::iter_to_array(results.into_iter())?;
|
||||
Ok(ColumnarValue::Array(results))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> VectorCalculator<'_, F>
|
||||
where
|
||||
F: Fn(&ScalarValue) -> Result<ScalarValue>,
|
||||
{
|
||||
fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
|
||||
let [arg0] = utils::take_function_args(self.name, &args.args)?;
|
||||
|
||||
let arg0 = match arg0 {
|
||||
ColumnarValue::Scalar(v) => {
|
||||
let result = (self.func)(v)?;
|
||||
return Ok(ColumnarValue::Scalar(result));
|
||||
}
|
||||
ColumnarValue::Array(a) => a,
|
||||
};
|
||||
|
||||
let len = arg0.len();
|
||||
let mut results = Vec::with_capacity(len);
|
||||
for i in 0..len {
|
||||
let v = ScalarValue::try_from_array(arg0, i)?;
|
||||
results.push((self.func)(&v)?);
|
||||
}
|
||||
|
||||
let results = ScalarValue::iter_to_array(results.into_iter())?;
|
||||
Ok(ColumnarValue::Array(results))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::fmt::Display;
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_expr::type_coercion::aggregates::BINARYS;
|
||||
use datafusion_expr::{Signature, Volatility};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::types::vector_type_value_to_string;
|
||||
use datatypes::value::Value;
|
||||
@@ -41,7 +41,13 @@ impl Function for VectorToStringFunction {
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::uniform(1, BINARYS.to_vec(), Volatility::Immutable)
|
||||
Signature::one_of(
|
||||
vec![
|
||||
TypeSignature::Uniform(1, vec![DataType::BinaryView]),
|
||||
TypeSignature::Uniform(1, BINARYS.to_vec()),
|
||||
],
|
||||
Volatility::Immutable,
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
|
||||
@@ -19,20 +19,17 @@ mod l2sq;
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use common_query::error::Result;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
|
||||
|
||||
macro_rules! define_distance_function {
|
||||
($StructName:ident, $display_name:expr, $similarity_method:path) => {
|
||||
|
||||
/// A function calculates the distance between two vectors.
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
@@ -54,59 +51,34 @@ macro_rules! define_distance_function {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &Option<Cow<[f32]>>,
|
||||
v1: &Option<Cow<[f32]>>|
|
||||
-> datafusion_common::Result<ScalarValue> {
|
||||
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
|
||||
if v0.len() != v1.len() {
|
||||
return Err(datafusion_common::DataFusionError::Execution(format!(
|
||||
"vectors length not match: {}",
|
||||
self.name()
|
||||
)));
|
||||
}
|
||||
|
||||
let size = arg0.len();
|
||||
let mut result = Float32VectorBuilder::with_capacity(size);
|
||||
if size == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
let arg1_const = as_veclit_if_const(arg1)?;
|
||||
|
||||
for i in 0..size {
|
||||
let vec0 = match arg0_const.as_ref() {
|
||||
Some(a) => Some(Cow::Borrowed(a.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let vec1 = match arg1_const.as_ref() {
|
||||
Some(b) => Some(Cow::Borrowed(b.as_ref())),
|
||||
None => as_veclit(arg1.get_ref(i))?,
|
||||
};
|
||||
|
||||
if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
|
||||
ensure!(
|
||||
vec0.len() == vec1.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the vectors must match to calculate distance, have: {} vs {}",
|
||||
vec0.len(),
|
||||
vec1.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
// Checked if the length of the vectors match
|
||||
let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
|
||||
result.push(Some(d));
|
||||
let d = $similarity_method(v0, v1);
|
||||
Some(d)
|
||||
} else {
|
||||
result.push_null();
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
Ok(ScalarValue::Float32(result))
|
||||
};
|
||||
|
||||
return Ok(result.to_vector());
|
||||
let calculator = $crate::scalars::vector::VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_vectors(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +87,7 @@ macro_rules! define_distance_function {
|
||||
write!(f, "{}", $display_name.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
|
||||
@@ -126,10 +98,29 @@ define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
|
||||
use datafusion::arrow::datatypes::Float32Type;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
|
||||
let number_rows = args[0].len();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: args
|
||||
.iter()
|
||||
.map(|x| ColumnarValue::Array(x.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
arg_fields: vec![],
|
||||
number_rows,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
func.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(number_rows))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_distance_string_string() {
|
||||
let funcs = [
|
||||
@@ -139,36 +130,34 @@ mod tests {
|
||||
];
|
||||
|
||||
for func in funcs {
|
||||
let vec1 = Arc::new(StringVector::from(vec![
|
||||
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[0.0, 1.0]"),
|
||||
Some("[1.0, 0.0]"),
|
||||
None,
|
||||
Some("[1.0, 0.0]"),
|
||||
])) as VectorRef;
|
||||
let vec2 = Arc::new(StringVector::from(vec![
|
||||
]));
|
||||
let vec2: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[0.0, 1.0]"),
|
||||
Some("[0.0, 1.0]"),
|
||||
Some("[0.0, 1.0]"),
|
||||
None,
|
||||
])) as VectorRef;
|
||||
]));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
|
||||
.unwrap();
|
||||
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
assert!(!result.is_null(0));
|
||||
assert!(!result.is_null(1));
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[vec2, vec1])
|
||||
.unwrap();
|
||||
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
assert!(!result.is_null(0));
|
||||
assert!(!result.is_null(1));
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,37 +170,35 @@ mod tests {
|
||||
];
|
||||
|
||||
for func in funcs {
|
||||
let vec1 = Arc::new(BinaryVector::from(vec![
|
||||
let vec1: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
|
||||
None,
|
||||
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
|
||||
])) as VectorRef;
|
||||
let vec2 = Arc::new(BinaryVector::from(vec![
|
||||
]));
|
||||
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
|
||||
// [0.0, 1.0]
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
None,
|
||||
])) as VectorRef;
|
||||
]));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
|
||||
.unwrap();
|
||||
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
assert!(!result.is_null(0));
|
||||
assert!(!result.is_null(1));
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[vec2, vec1])
|
||||
.unwrap();
|
||||
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
assert!(!result.is_null(0));
|
||||
assert!(!result.is_null(1));
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,115 +211,35 @@ mod tests {
|
||||
];
|
||||
|
||||
for func in funcs {
|
||||
let vec1 = Arc::new(StringVector::from(vec![
|
||||
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[0.0, 1.0]"),
|
||||
Some("[1.0, 0.0]"),
|
||||
None,
|
||||
Some("[1.0, 0.0]"),
|
||||
])) as VectorRef;
|
||||
let vec2 = Arc::new(BinaryVector::from(vec![
|
||||
]));
|
||||
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
|
||||
// [0.0, 1.0]
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
None,
|
||||
])) as VectorRef;
|
||||
]));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
|
||||
.unwrap();
|
||||
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
assert!(!result.is_null(0));
|
||||
assert!(!result.is_null(1));
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[vec2, vec1])
|
||||
.unwrap();
|
||||
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_distance_const_string() {
|
||||
let funcs = [
|
||||
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
|
||||
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
|
||||
Box::new(DotProductFunction {}) as Box<dyn Function>,
|
||||
];
|
||||
|
||||
for func in funcs {
|
||||
let const_str = Arc::new(ConstantVector::new(
|
||||
Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
|
||||
4,
|
||||
));
|
||||
|
||||
let vec1 = Arc::new(StringVector::from(vec![
|
||||
Some("[0.0, 1.0]"),
|
||||
Some("[1.0, 0.0]"),
|
||||
None,
|
||||
Some("[1.0, 0.0]"),
|
||||
])) as VectorRef;
|
||||
let vec2 = Arc::new(BinaryVector::from(vec![
|
||||
// [0.0, 1.0]
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
|
||||
None,
|
||||
])) as VectorRef;
|
||||
|
||||
let result = func
|
||||
.eval(
|
||||
&FunctionContext::default(),
|
||||
&[const_str.clone(), vec1.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(!result.get(3).is_null());
|
||||
|
||||
let result = func
|
||||
.eval(
|
||||
&FunctionContext::default(),
|
||||
&[vec1.clone(), const_str.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(result.get(2).is_null());
|
||||
assert!(!result.get(3).is_null());
|
||||
|
||||
let result = func
|
||||
.eval(
|
||||
&FunctionContext::default(),
|
||||
&[const_str.clone(), vec2.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(!result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
|
||||
let result = func
|
||||
.eval(
|
||||
&FunctionContext::default(),
|
||||
&[vec2.clone(), const_str.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.get(0).is_null());
|
||||
assert!(!result.get(1).is_null());
|
||||
assert!(!result.get(2).is_null());
|
||||
assert!(result.get(3).is_null());
|
||||
assert!(!result.is_null(0));
|
||||
assert!(!result.is_null(1));
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,15 +252,16 @@ mod tests {
|
||||
];
|
||||
|
||||
for func in funcs {
|
||||
let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
|
||||
let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
|
||||
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
|
||||
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"]));
|
||||
let vec2: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"]));
|
||||
let result = test_invoke(func.as_ref(), &[vec1, vec2]);
|
||||
assert!(result.is_err());
|
||||
|
||||
let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
|
||||
let vec2 =
|
||||
Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
|
||||
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
|
||||
let vec1: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]]));
|
||||
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![
|
||||
0, 0, 128, 63, 0, 0, 0, 64,
|
||||
]]));
|
||||
let result = test_invoke(func.as_ref(), &[vec1, vec2]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,20 +12,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
|
||||
use crate::function::Function;
|
||||
use crate::scalars::vector::{VectorCalculator, impl_conv};
|
||||
|
||||
const NAME: &str = "vec_elem_product";
|
||||
|
||||
@@ -64,43 +62,21 @@ impl Function for ElemProductFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
columns: &[VectorRef],
|
||||
) -> common_query::error::Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly one, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
|
||||
let v0 = impl_conv::as_veclit(v0)?
|
||||
.map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
|
||||
Ok(ScalarValue::Float32(v0))
|
||||
};
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = Float32VectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_single_argument(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,27 +90,39 @@ impl Display for ElemProductFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, StringArray};
|
||||
use datafusion::arrow::datatypes::Float32Type;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
use crate::function::FunctionContext;
|
||||
|
||||
#[test]
|
||||
fn test_elem_product() {
|
||||
let func = ElemProductFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input = Arc::new(StringArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
|
||||
let result = func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input.clone())],
|
||||
arg_fields: vec![],
|
||||
number_rows: input.len(),
|
||||
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
|
||||
.map(|mut a| a.remove(0))
|
||||
.unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
let result = result.as_ref();
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
|
||||
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(120.0));
|
||||
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
|
||||
assert_eq!(result.value(0), 6.0);
|
||||
assert_eq!(result.value(1), 120.0);
|
||||
assert!(result.is_null(2));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,20 +12,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
|
||||
use crate::function::Function;
|
||||
use crate::scalars::vector::{VectorCalculator, impl_conv};
|
||||
|
||||
const NAME: &str = "vec_elem_sum";
|
||||
|
||||
@@ -51,43 +49,21 @@ impl Function for ElemSumFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
columns: &[VectorRef],
|
||||
) -> common_query::error::Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly one, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
|
||||
let v0 =
|
||||
impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
|
||||
Ok(ScalarValue::Float32(v0))
|
||||
};
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = Float32VectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum()));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_single_argument(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,27 +77,40 @@ impl Display for ElemSumFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow::array::StringViewArray;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray};
|
||||
use datafusion::arrow::datatypes::Float32Type;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
use crate::function::FunctionContext;
|
||||
|
||||
#[test]
|
||||
fn test_elem_sum() {
|
||||
let func = ElemSumFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
|
||||
let result = func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input.clone())],
|
||||
arg_fields: vec![],
|
||||
number_rows: input.len(),
|
||||
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
})
|
||||
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
|
||||
.map(|mut a| a.remove(0))
|
||||
.unwrap();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
|
||||
let result = result.as_ref();
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
|
||||
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(15.0));
|
||||
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
|
||||
assert_eq!(result.value(0), 6.0);
|
||||
assert_eq!(result.value(1), 15.0);
|
||||
assert!(result.is_null(2));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,40 +13,18 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::value::ValueRef;
|
||||
use datatypes::vectors::Vector;
|
||||
|
||||
/// Convert a constant string or binary literal to a vector literal.
|
||||
pub fn as_veclit_if_const(arg: &Arc<dyn Vector>) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
if !arg.is_const() {
|
||||
return Ok(None);
|
||||
}
|
||||
if arg.data_type() != ConcreteDataType::string_datatype()
|
||||
&& arg.data_type() != ConcreteDataType::binary_datatype()
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
as_veclit(arg.get_ref(0))
|
||||
}
|
||||
use datafusion_common::ScalarValue;
|
||||
|
||||
/// Convert a string or binary literal to a vector literal.
|
||||
pub fn as_veclit(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
match arg.data_type() {
|
||||
ConcreteDataType::Binary(_) => arg
|
||||
.as_binary()
|
||||
.unwrap() // Safe: checked if it is a binary
|
||||
.map(binlit_as_veclit)
|
||||
pub fn as_veclit(arg: &ScalarValue) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
match arg {
|
||||
ScalarValue::Binary(b) => b.as_ref().map(|x| binlit_as_veclit(x)).transpose(),
|
||||
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) => s
|
||||
.as_ref()
|
||||
.map(|x| parse_veclit_from_strlit(x).map(Cow::Owned))
|
||||
.transpose(),
|
||||
ConcreteDataType::String(_) => arg
|
||||
.as_string()
|
||||
.unwrap() // Safe: checked if it is a string
|
||||
.map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?)))
|
||||
.transpose(),
|
||||
ConcreteDataType::Null(_) => Ok(None),
|
||||
_ => InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
|
||||
}
|
||||
|
||||
@@ -12,20 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
|
||||
|
||||
const NAME: &str = "vec_scalar_add";
|
||||
|
||||
@@ -60,7 +59,7 @@ impl Function for ScalarAddFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -70,52 +69,26 @@ impl Function for ScalarAddFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg1_const = as_veclit_if_const(arg1)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = arg0.get(i).as_f64_lossy();
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
|
||||
let ScalarValue::Float64(Some(v0)) = v0 else {
|
||||
return Ok(ScalarValue::BinaryView(None));
|
||||
};
|
||||
|
||||
let arg1 = match arg1_const.as_ref() {
|
||||
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
|
||||
None => as_veclit(arg1.get_ref(i))?,
|
||||
};
|
||||
let Some(arg1) = arg1 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
let v1 = as_veclit(v1)?
|
||||
.map(|v1| DVectorView::from_slice(&v1, v1.len()).add_scalar(*v0 as f32));
|
||||
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
|
||||
Ok(ScalarValue::BinaryView(result))
|
||||
};
|
||||
|
||||
let vec = DVectorView::from_slice(&arg1, arg1.len());
|
||||
let vec_res = vec.add_scalar(arg0 as _);
|
||||
|
||||
let veclit = vec_res.as_slice();
|
||||
let binlit = veclit_to_binlit(veclit);
|
||||
result.push(Some(&binlit));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_args(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,7 +102,9 @@ impl Display for ScalarAddFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::{Float32Vector, StringVector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -137,34 +112,42 @@ mod tests {
|
||||
fn test_scalar_add() {
|
||||
let func = ScalarAddFunction;
|
||||
|
||||
let input0 = Arc::new(Float32Vector::from(vec![
|
||||
let input0 = Arc::new(Float64Array::from(vec![
|
||||
Some(1.0),
|
||||
Some(-1.0),
|
||||
None,
|
||||
Some(3.0),
|
||||
]));
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
Some("[7.0,8.0,9.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice()
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,20 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
|
||||
|
||||
const NAME: &str = "vec_scalar_mul";
|
||||
|
||||
@@ -60,7 +59,7 @@ impl Function for ScalarMulFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -70,52 +69,26 @@ impl Function for ScalarMulFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg1_const = as_veclit_if_const(arg1)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = arg0.get(i).as_f64_lossy();
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
|
||||
let ScalarValue::Float64(Some(v0)) = v0 else {
|
||||
return Ok(ScalarValue::BinaryView(None));
|
||||
};
|
||||
|
||||
let arg1 = match arg1_const.as_ref() {
|
||||
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
|
||||
None => as_veclit(arg1.get_ref(i))?,
|
||||
};
|
||||
let Some(arg1) = arg1 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
let v1 =
|
||||
as_veclit(v1)?.map(|v1| DVectorView::from_slice(&v1, v1.len()).scale(*v0 as f32));
|
||||
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
|
||||
Ok(ScalarValue::BinaryView(result))
|
||||
};
|
||||
|
||||
let vec = DVectorView::from_slice(&arg1, arg1.len());
|
||||
let vec_res = vec.scale(arg0 as _);
|
||||
|
||||
let veclit = vec_res.as_slice();
|
||||
let binlit = veclit_to_binlit(veclit);
|
||||
result.push(Some(&binlit));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_args(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,7 +102,9 @@ impl Display for ScalarMulFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::{Float32Vector, StringVector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -137,34 +112,42 @@ mod tests {
|
||||
fn test_scalar_mul() {
|
||||
let func = ScalarMulFunction;
|
||||
|
||||
let input0 = Arc::new(Float32Vector::from(vec![
|
||||
let input0 = Arc::new(Float64Array::from(vec![
|
||||
Some(2.0),
|
||||
Some(-0.5),
|
||||
None,
|
||||
Some(3.0),
|
||||
]));
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[8.0,10.0,12.0]".to_string()),
|
||||
Some("[7.0,8.0,9.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice()
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,17 +15,17 @@
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::veclit_to_binlit;
|
||||
|
||||
const NAME: &str = "vec_add";
|
||||
|
||||
@@ -51,7 +51,7 @@ impl Function for VectorAddFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -61,66 +61,36 @@ impl Function for VectorAddFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
columns: &[VectorRef],
|
||||
) -> common_query::error::Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &Option<Cow<[f32]>>,
|
||||
v1: &Option<Cow<[f32]>>|
|
||||
-> datafusion_common::Result<ScalarValue> {
|
||||
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
|
||||
let v0 = DVectorView::from_slice(v0, v0.len());
|
||||
let v1 = DVectorView::from_slice(v1, v1.len());
|
||||
if v0.len() != v1.len() {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"vectors length not match: {}",
|
||||
self.name()
|
||||
)));
|
||||
}
|
||||
|
||||
ensure!(
|
||||
arg0.len() == arg1.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
|
||||
arg0.len(),
|
||||
arg1.len(),
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
let arg1_const = as_veclit_if_const(arg1)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
let result = veclit_to_binlit((v0 + v1).as_slice());
|
||||
Some(result)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let arg1 = match arg1_const.as_ref() {
|
||||
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
|
||||
None => as_veclit(arg1.get_ref(i))?,
|
||||
};
|
||||
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
|
||||
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
|
||||
Ok(ScalarValue::BinaryView(result))
|
||||
};
|
||||
|
||||
let vec_res = vec0 + vec1;
|
||||
let veclit = vec_res.as_slice();
|
||||
let binlit = veclit_to_binlit(veclit);
|
||||
result.push(Some(&binlit));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_vectors(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,8 +104,9 @@ impl Display for VectorAddFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Error;
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -143,63 +114,71 @@ mod tests {
|
||||
fn test_sub() {
|
||||
let func = VectorAddFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
None,
|
||||
Some("[2.0,3.0,3.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,1.0,1.0]".to_string()),
|
||||
Some("[6.0,5.0,4.0]".to_string()),
|
||||
Some("[3.0,2.0,2.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_error() {
|
||||
let func = VectorAddFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
None,
|
||||
Some("[2.0,3.0,3.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,1.0,1.0]".to_string()),
|
||||
Some("[6.0,5.0,4.0]".to_string()),
|
||||
Some("[3.0,2.0,2.0]".to_string()),
|
||||
]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(e.to_string().starts_with(
|
||||
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,19 +12,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
|
||||
use snafu::ensure;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
|
||||
use crate::function::Function;
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::as_veclit;
|
||||
|
||||
const NAME: &str = "vec_dim";
|
||||
|
||||
@@ -63,43 +62,20 @@ impl Function for VectorDimFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
columns: &[VectorRef],
|
||||
) -> common_query::error::Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly one, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
|
||||
let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
|
||||
Ok(ScalarValue::UInt64(v))
|
||||
};
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = UInt64VectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
result.push(Some(arg0.len() as u64));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_single_argument(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,8 +89,10 @@ impl Display for VectorDimFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Error;
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
|
||||
use datafusion::arrow::datatypes::UInt64Type;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -122,49 +100,60 @@ mod tests {
|
||||
fn test_vec_dim() {
|
||||
let func = VectorDimFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[0.0,2.0,3.0]".to_string()),
|
||||
Some("[1.0,2.0,3.0,4.0]".to_string()),
|
||||
None,
|
||||
Some("[5.0]".to_string()),
|
||||
]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_primitive::<UInt64Type>();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(result.get_ref(0).as_u64().unwrap(), Some(3));
|
||||
assert_eq!(result.get_ref(1).as_u64().unwrap(), Some(4));
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert_eq!(result.get_ref(3).as_u64().unwrap(), Some(1));
|
||||
assert_eq!(result.value(0), 3);
|
||||
assert_eq!(result.value(1), 4);
|
||||
assert!(result.is_null(2));
|
||||
assert_eq!(result.value(3), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dim_error() {
|
||||
let func = VectorDimFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
None,
|
||||
Some("[2.0,3.0,3.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,1.0,1.0]".to_string()),
|
||||
Some("[6.0,5.0,4.0]".to_string()),
|
||||
Some("[3.0,2.0,2.0]".to_string()),
|
||||
]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"The length of the args is not correct, expect exactly one, have: 2"
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(
|
||||
e.to_string()
|
||||
.starts_with("Execution error: vec_dim function requires 1 argument, got 2")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,17 +15,17 @@
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::veclit_to_binlit;
|
||||
|
||||
const NAME: &str = "vec_div";
|
||||
|
||||
@@ -52,7 +52,7 @@ impl Function for VectorDivFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -62,64 +62,36 @@ impl Function for VectorDivFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &Option<Cow<[f32]>>,
|
||||
v1: &Option<Cow<[f32]>>|
|
||||
-> datafusion_common::Result<ScalarValue> {
|
||||
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
|
||||
let v0 = DVectorView::from_slice(v0, v0.len());
|
||||
let v1 = DVectorView::from_slice(v1, v1.len());
|
||||
if v0.len() != v1.len() {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"vectors length not match: {}",
|
||||
self.name()
|
||||
)));
|
||||
}
|
||||
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
let arg1_const = as_veclit_if_const(arg1)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
|
||||
let arg1 = match arg1_const.as_ref() {
|
||||
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
|
||||
None => as_veclit(arg1.get_ref(i))?,
|
||||
};
|
||||
|
||||
if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
|
||||
ensure!(
|
||||
arg0.len() == arg1.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the vectors must match for division, have: {} vs {}",
|
||||
arg0.len(),
|
||||
arg1.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
|
||||
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
|
||||
let vec_res = vec0.component_div(&vec1);
|
||||
|
||||
let veclit = vec_res.as_slice();
|
||||
let binlit = veclit_to_binlit(veclit);
|
||||
result.push(Some(&binlit));
|
||||
let result = veclit_to_binlit((v0.component_div(&v1)).as_slice());
|
||||
Some(result)
|
||||
} else {
|
||||
result.push_null();
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
Ok(ScalarValue::BinaryView(result))
|
||||
};
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_vectors(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,8 +105,9 @@ impl Display for VectorDivFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error;
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -144,69 +117,80 @@ mod tests {
|
||||
|
||||
let vec0 = vec![1.0, 2.0, 3.0];
|
||||
let vec1 = vec![1.0, 1.0];
|
||||
let (len0, len1) = (vec0.len(), vec1.len());
|
||||
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
|
||||
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
|
||||
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
|
||||
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));
|
||||
|
||||
let err = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.unwrap_err();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert_eq!(
|
||||
e.to_string(),
|
||||
"Execution error: vectors length not match: vec_div"
|
||||
);
|
||||
|
||||
match err {
|
||||
error::Error::InvalidFuncArgs { err_msg, .. } => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
format!(
|
||||
"The length of the vectors must match for division, have: {} vs {}",
|
||||
len0, len1
|
||||
)
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[8.0,10.0,12.0]".to_string()),
|
||||
Some("[7.0,8.0,9.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,1.0,1.0]".to_string()),
|
||||
Some("[2.0,2.0,2.0]".to_string()),
|
||||
None,
|
||||
Some("[3.0,3.0,3.0]".to_string()),
|
||||
]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice()
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())]));
|
||||
let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));
|
||||
let input0 = Arc::new(StringViewArray::from(vec![Some("[1.0,-2.0]".to_string())]));
|
||||
let input1 = Arc::new(StringViewArray::from(vec![Some("[0.0,0.0]".to_string())]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 2,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(2))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,19 +12,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use common_query::error::Result;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::as_veclit;
|
||||
|
||||
const NAME: &str = "vec_kth_elem";
|
||||
|
||||
@@ -63,72 +62,44 @@ impl Function for VectorKthElemFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
|
||||
let v0 = as_veclit(v0)?;
|
||||
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
let v1 = match v1 {
|
||||
ScalarValue::Int64(None) => return Ok(ScalarValue::Float32(None)),
|
||||
ScalarValue::Int64(Some(v1)) if *v1 >= 0 => *v1 as usize,
|
||||
_ => {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"2nd argument not a valid index or expected datatype: {}",
|
||||
self.name()
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = Float32VectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
let result = v0
|
||||
.map(|v0| {
|
||||
if v1 >= v0.len() {
|
||||
Err(DataFusionError::Execution(format!(
|
||||
"index out of bound: {}",
|
||||
self.name()
|
||||
)))
|
||||
} else {
|
||||
Ok(v0[v1])
|
||||
}
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(ScalarValue::Float32(result))
|
||||
};
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
|
||||
let arg1 = arg1.get(i).as_f64_lossy();
|
||||
let Some(arg1) = arg1 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
|
||||
ensure!(
|
||||
arg1 >= 0.0 && arg1.fract() == 0.0,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Invalid argument: k must be a non-negative integer, but got k = {}.",
|
||||
arg1
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
let k = arg1 as usize;
|
||||
|
||||
ensure!(
|
||||
k < arg0.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Out of range: k must be in the range [0, {}], but got k = {}.",
|
||||
arg0.len() - 1,
|
||||
k
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
let value = arg0[k];
|
||||
|
||||
result.push(Some(value));
|
||||
}
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_args(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,8 +113,10 @@ impl Display for VectorKthElemFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error;
|
||||
use datatypes::vectors::{Int64Vector, StringVector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringViewArray};
|
||||
use datafusion::arrow::datatypes::Float32Type;
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -151,55 +124,66 @@ mod tests {
|
||||
fn test_vec_kth_elem() {
|
||||
let func = VectorKthElemFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
Some("[7.0,8.0,9.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));
|
||||
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(2), None, Some(1)]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_primitive::<Float32Type>();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
|
||||
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
assert_eq!(result.value(0), 1.0);
|
||||
assert_eq!(result.value(1), 6.0);
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(3)]));
|
||||
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
|
||||
"[1.0,2.0,3.0]".to_string(),
|
||||
)]));
|
||||
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));
|
||||
|
||||
let err = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.unwrap_err();
|
||||
match err {
|
||||
error::Error::InvalidFuncArgs { err_msg, .. } => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
format!("Out of range: k must be in the range [0, 2], but got k = 3.")
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(
|
||||
e.to_string()
|
||||
.starts_with("Execution error: index out of bound: vec_kth_elem")
|
||||
);
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));
|
||||
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
|
||||
"[1.0,2.0,3.0]".to_string(),
|
||||
)]));
|
||||
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(-1)]));
|
||||
|
||||
let err = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.unwrap_err();
|
||||
match err {
|
||||
error::Error::InvalidFuncArgs { err_msg, .. } => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
format!("Invalid argument: k must be a non-negative integer, but got k = -1.")
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(e.to_string().starts_with(
|
||||
"Execution error: 2nd argument not a valid index or expected datatype: vec_kth_elem"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,17 +15,17 @@
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::veclit_to_binlit;
|
||||
|
||||
const NAME: &str = "vec_mul";
|
||||
|
||||
@@ -52,7 +52,7 @@ impl Function for VectorMulFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -62,64 +62,36 @@ impl Function for VectorMulFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &Option<Cow<[f32]>>,
|
||||
v1: &Option<Cow<[f32]>>|
|
||||
-> datafusion_common::Result<ScalarValue> {
|
||||
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
|
||||
let v0 = DVectorView::from_slice(v0, v0.len());
|
||||
let v1 = DVectorView::from_slice(v1, v1.len());
|
||||
if v0.len() != v1.len() {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"vectors length not match: {}",
|
||||
self.name()
|
||||
)));
|
||||
}
|
||||
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
let arg1_const = as_veclit_if_const(arg1)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
|
||||
let arg1 = match arg1_const.as_ref() {
|
||||
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
|
||||
None => as_veclit(arg1.get_ref(i))?,
|
||||
};
|
||||
|
||||
if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
|
||||
ensure!(
|
||||
arg0.len() == arg1.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the vectors must match for multiplying, have: {} vs {}",
|
||||
arg0.len(),
|
||||
arg1.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
|
||||
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
|
||||
let vec_res = vec1.component_mul(&vec0);
|
||||
|
||||
let veclit = vec_res.as_slice();
|
||||
let binlit = veclit_to_binlit(veclit);
|
||||
result.push(Some(&binlit));
|
||||
let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
|
||||
Some(result)
|
||||
} else {
|
||||
result.push_null();
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
Ok(ScalarValue::BinaryView(result))
|
||||
};
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_vectors(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,8 +105,9 @@ impl Display for VectorMulFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error;
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -144,56 +117,59 @@ mod tests {
|
||||
|
||||
let vec0 = vec![1.0, 2.0, 3.0];
|
||||
let vec1 = vec![1.0, 1.0];
|
||||
let (len0, len1) = (vec0.len(), vec1.len());
|
||||
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
|
||||
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
|
||||
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
|
||||
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));
|
||||
|
||||
let err = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.unwrap_err();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(
|
||||
e.to_string()
|
||||
.starts_with("Execution error: vectors length not match: vec_mul")
|
||||
);
|
||||
|
||||
match err {
|
||||
error::Error::InvalidFuncArgs { err_msg, .. } => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
format!(
|
||||
"The length of the vectors must match for multiplying, have: {} vs {}",
|
||||
len0, len1
|
||||
)
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[8.0,10.0,12.0]".to_string()),
|
||||
Some("[7.0,8.0,9.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,1.0,1.0]".to_string()),
|
||||
Some("[2.0,2.0,2.0]".to_string()),
|
||||
None,
|
||||
Some("[3.0,3.0,3.0]".to_string()),
|
||||
]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,20 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::function::Function;
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
|
||||
|
||||
const NAME: &str = "vec_norm";
|
||||
|
||||
@@ -53,7 +52,7 @@ impl Function for VectorNormFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -66,55 +65,27 @@ impl Function for VectorNormFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
columns: &[VectorRef],
|
||||
) -> common_query::error::Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 1,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly one, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
|
||||
let v0 = as_veclit(v0)?;
|
||||
let Some(v0) = v0 else {
|
||||
return Ok(ScalarValue::BinaryView(None));
|
||||
};
|
||||
|
||||
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
|
||||
let vec1 = DVectorView::from_slice(&arg0, arg0.len());
|
||||
let vec2scalar = vec1.component_mul(&vec0);
|
||||
let scalar_var = vec2scalar.sum().sqrt();
|
||||
let v0 = DVectorView::from_slice(&v0, v0.len());
|
||||
let result =
|
||||
veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
|
||||
Ok(ScalarValue::BinaryView(Some(result)))
|
||||
};
|
||||
|
||||
let vec = DVectorView::from_slice(&arg0, arg0.len());
|
||||
// Use unscale to avoid division by zero and keep more precision as possible
|
||||
let vec_res = vec.unscale(scalar_var);
|
||||
|
||||
let veclit = vec_res.as_slice();
|
||||
let binlit = veclit_to_binlit(veclit);
|
||||
result.push(Some(&binlit));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_single_argument(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +99,9 @@ impl Display for VectorNormFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -136,7 +109,7 @@ mod tests {
|
||||
fn test_vec_norm() {
|
||||
let func = VectorNormFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[0.0,2.0,3.0]".to_string()),
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[7.0,8.0,9.0]".to_string()),
|
||||
@@ -144,26 +117,36 @@ mod tests {
|
||||
None,
|
||||
]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 5,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(5))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 5);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(2).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice())
|
||||
result.value(2),
|
||||
veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(3).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice())
|
||||
result.value(3),
|
||||
veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice()
|
||||
);
|
||||
assert!(result.get_ref(4).is_null());
|
||||
assert!(result.is_null(4));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,17 +15,17 @@
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use nalgebra::DVectorView;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::scalars::vector::VectorCalculator;
|
||||
use crate::scalars::vector::impl_conv::veclit_to_binlit;
|
||||
|
||||
const NAME: &str = "vec_sub";
|
||||
|
||||
@@ -51,7 +51,7 @@ impl Function for VectorSubFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -61,66 +61,36 @@ impl Function for VectorSubFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_func_ctx: &FunctionContext,
|
||||
columns: &[VectorRef],
|
||||
) -> common_query::error::Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let body = |v0: &Option<Cow<[f32]>>,
|
||||
v1: &Option<Cow<[f32]>>|
|
||||
-> datafusion_common::Result<ScalarValue> {
|
||||
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
|
||||
let v0 = DVectorView::from_slice(v0, v0.len());
|
||||
let v1 = DVectorView::from_slice(v1, v1.len());
|
||||
if v0.len() != v1.len() {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"vectors length not match: {}",
|
||||
self.name()
|
||||
)));
|
||||
}
|
||||
|
||||
ensure!(
|
||||
arg0.len() == arg1.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
|
||||
arg0.len(),
|
||||
arg1.len(),
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
let arg1_const = as_veclit_if_const(arg1)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
let result = veclit_to_binlit((v0 - v1).as_slice());
|
||||
Some(result)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let arg1 = match arg1_const.as_ref() {
|
||||
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
|
||||
None => as_veclit(arg1.get_ref(i))?,
|
||||
};
|
||||
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
|
||||
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
|
||||
Ok(ScalarValue::BinaryView(result))
|
||||
};
|
||||
|
||||
let vec_res = vec0 - vec1;
|
||||
let veclit = vec_res.as_slice();
|
||||
let binlit = veclit_to_binlit(veclit);
|
||||
result.push(Some(&binlit));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
let calculator = VectorCalculator {
|
||||
name: self.name(),
|
||||
func: body,
|
||||
};
|
||||
calculator.invoke_with_vectors(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,8 +104,9 @@ impl Display for VectorSubFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Error;
|
||||
use datatypes::vectors::StringVector;
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -143,63 +114,71 @@ mod tests {
|
||||
fn test_sub() {
|
||||
let func = VectorSubFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
None,
|
||||
Some("[2.0,3.0,3.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,1.0,1.0]".to_string()),
|
||||
Some("[6.0,5.0,4.0]".to_string()),
|
||||
Some("[3.0,2.0,2.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(4))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice())
|
||||
result.value(0),
|
||||
veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice()
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice()
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
assert!(result.is_null(2));
|
||||
assert!(result.is_null(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_error() {
|
||||
let func = VectorSubFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
None,
|
||||
Some("[2.0,3.0,3.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(StringVector::from(vec![
|
||||
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0,1.0,1.0]".to_string()),
|
||||
Some("[6.0,5.0,4.0]".to_string()),
|
||||
Some("[3.0,2.0,2.0]".to_string()),
|
||||
]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(e.to_string().starts_with(
|
||||
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,18 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::{Signature, TypeSignature, Volatility};
|
||||
use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
|
||||
use datafusion::arrow::datatypes::Int64Type;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{ScalarValue, utils};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
use crate::function::Function;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
|
||||
|
||||
const NAME: &str = "vec_subvector";
|
||||
|
||||
@@ -52,7 +54,7 @@ impl Function for VectorSubvectorFunction {
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
|
||||
Ok(DataType::Binary)
|
||||
Ok(DataType::BinaryView)
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -65,50 +67,28 @@ impl Function for VectorSubvectorFunction {
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 3,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly three, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
let arg2 = &columns[2];
|
||||
|
||||
ensure!(
|
||||
arg0.len() == arg1.len() && arg1.len() == arg2.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
|
||||
arg0.len(),
|
||||
arg1.len(),
|
||||
arg2.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?;
|
||||
let arg1 = arg1.as_primitive::<Int64Type>();
|
||||
let arg2 = arg2.as_primitive::<Int64Type>();
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
let mut builder = BinaryViewBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let arg1 = arg1.get(i).as_i64();
|
||||
let arg2 = arg2.get(i).as_i64();
|
||||
let v = ScalarValue::try_from_array(&arg0, i)?;
|
||||
let arg0 = as_veclit(&v)?;
|
||||
let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
|
||||
let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
|
||||
let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
|
||||
result.push_null();
|
||||
builder.append_null();
|
||||
continue;
|
||||
};
|
||||
|
||||
@@ -126,10 +106,10 @@ impl Function for VectorSubvectorFunction {
|
||||
|
||||
let subvector = &arg0[arg1 as usize..arg2 as usize];
|
||||
let binlit = veclit_to_binlit(subvector);
|
||||
result.push(Some(&binlit));
|
||||
builder.append_value(&binlit);
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,89 +123,102 @@ impl Display for VectorSubvectorFunction {
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Error;
|
||||
use datatypes::vectors::{Int64Vector, StringVector};
|
||||
use arrow_schema::Field;
|
||||
use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
|
||||
use datafusion_common::config::ConfigOptions;
|
||||
|
||||
use super::*;
|
||||
use crate::function::FunctionContext;
|
||||
|
||||
#[test]
|
||||
fn test_subvector() {
|
||||
let func = VectorSubvectorFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
|
||||
Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
|
||||
None,
|
||||
Some("[11.0, 12.0, 13.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
|
||||
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));
|
||||
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)]));
|
||||
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)]));
|
||||
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(input0),
|
||||
ColumnarValue::Array(input1),
|
||||
ColumnarValue::Array(input2),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 5,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1, input2])
|
||||
.invoke_with_args(args)
|
||||
.and_then(|x| x.to_array(5))
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
let result = result.as_binary_view();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert_eq!(
|
||||
result.get_ref(3).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
|
||||
result.value(1),
|
||||
veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
|
||||
);
|
||||
assert!(result.is_null(2));
|
||||
assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
|
||||
}
|
||||
#[test]
|
||||
fn test_subvector_error() {
|
||||
let func = VectorSubvectorFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0, 2.0, 3.0]".to_string()),
|
||||
Some("[4.0, 5.0, 6.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
|
||||
let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));
|
||||
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)]));
|
||||
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(input0),
|
||||
ColumnarValue::Array(input1),
|
||||
ColumnarValue::Array(input2),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(e.to_string().starts_with(
|
||||
"Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subvector_invalid_indices() {
|
||||
let func = VectorSubvectorFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
let input0 = Arc::new(StringViewArray::from(vec![
|
||||
Some("[1.0, 2.0, 3.0]".to_string()),
|
||||
Some("[4.0, 5.0, 6.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
|
||||
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));
|
||||
let input1 = Arc::new(Int64Array::from(vec![Some(1), Some(3)]));
|
||||
let input2 = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"Invalid start and end indices: start=3, end=4, vec_len=3"
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let args = ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(input0),
|
||||
ColumnarValue::Array(input1),
|
||||
ColumnarValue::Array(input2),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 3,
|
||||
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
|
||||
config_options: Arc::new(ConfigOptions::new()),
|
||||
};
|
||||
let e = func.invoke_with_args(args).unwrap_err();
|
||||
assert!(e.to_string().starts_with("External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,11 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use common_function::scalars::vector::impl_conv::{
|
||||
as_veclit, as_veclit_if_const, veclit_to_binlit,
|
||||
};
|
||||
use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datatypes::prelude::Value;
|
||||
use nalgebra::{Const, DVectorView, Dyn, OVector};
|
||||
|
||||
@@ -35,14 +32,10 @@ async fn test_vec_product_aggregator() -> Result<(), common_query::error::Error>
|
||||
let sql = "SELECT vector FROM vectors";
|
||||
let vectors = exec_selection(engine, sql).await;
|
||||
|
||||
let column = vectors[0].column(0);
|
||||
let vector_const = as_veclit_if_const(column)?;
|
||||
|
||||
let column = vectors[0].column(0).to_arrow_array();
|
||||
for i in 0..column.len() {
|
||||
let vector = match vector_const.as_ref() {
|
||||
Some(vector) => Some(Cow::Borrowed(vector.as_ref())),
|
||||
None => as_veclit(column.get_ref(i))?,
|
||||
};
|
||||
let v = ScalarValue::try_from_array(&column, i)?;
|
||||
let vector = as_veclit(&v)?;
|
||||
let Some(vector) = vector else {
|
||||
expected_value = None;
|
||||
break;
|
||||
|
||||
@@ -12,12 +12,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::ops::AddAssign;
|
||||
|
||||
use common_function::scalars::vector::impl_conv::{
|
||||
as_veclit, as_veclit_if_const, veclit_to_binlit,
|
||||
};
|
||||
use common_function::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datatypes::prelude::Value;
|
||||
use nalgebra::{Const, DVectorView, Dyn, OVector};
|
||||
|
||||
@@ -36,14 +34,10 @@ async fn test_vec_sum_aggregator() -> Result<(), common_query::error::Error> {
|
||||
let sql = "SELECT vector FROM vectors";
|
||||
let vectors = exec_selection(engine, sql).await;
|
||||
|
||||
let column = vectors[0].column(0);
|
||||
let vector_const = as_veclit_if_const(column)?;
|
||||
|
||||
let column = vectors[0].column(0).to_arrow_array();
|
||||
for i in 0..column.len() {
|
||||
let vector = match vector_const.as_ref() {
|
||||
Some(vector) => Some(Cow::Borrowed(vector.as_ref())),
|
||||
None => as_veclit(column.get_ref(i))?,
|
||||
};
|
||||
let v = ScalarValue::try_from_array(&column, i)?;
|
||||
let vector = as_veclit(&v)?;
|
||||
let Some(vector) = vector else {
|
||||
expected_value = None;
|
||||
break;
|
||||
|
||||
@@ -76,13 +76,7 @@ SELECT CLAMP(0.5, 0, 1);
|
||||
|
||||
SELECT CLAMP(10, 1, 0);
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray<Int64>
|
||||
[
|
||||
1,
|
||||
] }, PrimitiveVector { array: PrimitiveArray<Int64>
|
||||
[
|
||||
0,
|
||||
] }
|
||||
Error: 3001(EngineExecuteQuery), Execution error: min '1' > max '0'
|
||||
|
||||
SELECT CLAMP_MIN(10, 12);
|
||||
|
||||
|
||||
@@ -240,11 +240,11 @@ SELECT geohash(37.76938, -122.3889, 11);
|
||||
|
||||
SELECT geohash(37.76938, -122.3889, 100);
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid geohash resolution 100, expect value: [1, 12]
|
||||
Error: 3001(EngineExecuteQuery), Execution error: Invalid geohash resolution 100, valid value range: [1, 12]
|
||||
|
||||
SELECT geohash(37.76938, -122.3889, -1);
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid geohash resolution -1, expect value: [1, 12]
|
||||
Error: 3001(EngineExecuteQuery), Cast error: Can't cast value -1 to type UInt8
|
||||
|
||||
SELECT geohash(37.76938, -122.3889, 11::Int8);
|
||||
|
||||
|
||||
@@ -375,17 +375,7 @@ TQL EVAL (0, 15, '5s') clamp(host, 6 - 6, 6 + 6);
|
||||
-- SQLNESS SORT_RESULT 3 1
|
||||
TQL EVAL (0, 15, '5s') clamp(host, 12, 0);
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: PrimitiveVector { array: PrimitiveArray<Float64>
|
||||
[
|
||||
12.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
12.0,
|
||||
12.0,
|
||||
[
|
||||
] }, PrimitiveVector { array: PrimitiveArray<Float64>
|
||||
] }
|
||||
Error: 3001(EngineExecuteQuery), Execution error: min '12' > max '0'
|
||||
|
||||
-- SQLNESS SORT_RESULT 3 1
|
||||
TQL EVAL (0, 15, '5s') clamp(host{host="host1"}, -1, 6);
|
||||
|
||||
@@ -99,7 +99,7 @@ SELECT round(vec_cos_distance(v, v), 2) FROM t;
|
||||
-- Unexpected dimension --
|
||||
SELECT vec_cos_distance(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_cos_distance
|
||||
|
||||
-- Invalid type --
|
||||
SELECT vec_cos_distance(v, 1.0) FROM t;
|
||||
@@ -174,7 +174,7 @@ SELECT round(vec_l2sq_distance(v, v), 2) FROM t;
|
||||
-- Unexpected dimension --
|
||||
SELECT vec_l2sq_distance(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_l2sq_distance
|
||||
|
||||
-- Invalid type --
|
||||
SELECT vec_l2sq_distance(v, 1.0) FROM t;
|
||||
@@ -249,7 +249,7 @@ SELECT round(vec_dot_product(v, v), 2) FROM t;
|
||||
-- Unexpected dimension --
|
||||
SELECT vec_dot_product(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
Error: 3001(EngineExecuteQuery), Execution error: vectors length not match: vec_dot_product
|
||||
|
||||
-- Invalid type --
|
||||
SELECT vec_dot_product(v, 1.0) FROM t;
|
||||
|
||||
Reference in New Issue
Block a user