refactor: simplify udf (#5617)

* refactor: simplify udf

* fix tests
This commit is contained in:
LFC
2025-03-03 13:52:44 +08:00
committed by GitHub
parent 11a4f54c49
commit dee76f0a73
68 changed files with 323 additions and 751 deletions

2
Cargo.lock generated
View File

@@ -2026,6 +2026,8 @@ dependencies = [
"common-time",
"common-version",
"datafusion",
"datafusion-common",
"datafusion-expr",
"datatypes",
"derive_more",
"geo",

View File

@@ -28,6 +28,8 @@ common-telemetry.workspace = true
common-time.workspace = true
common-version.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datatypes.workspace = true
derive_more = { version = "1", default-features = false, features = ["display"] }
geo = { version = "0.29", optional = true }

View File

@@ -63,7 +63,7 @@ pub trait Function: fmt::Display + Sync + Send {
fn signature(&self) -> Signature;
/// Evaluate the function, e.g. run/execute the function.
fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef>;
fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef>;
}
pub type FunctionRef = Arc<dyn Function>;

View File

@@ -58,7 +58,7 @@ impl Function for DateAddFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -146,7 +146,7 @@ mod tests {
let time_vector = TimestampSecondVector::from(times.clone());
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
@@ -178,7 +178,7 @@ mod tests {
let date_vector = DateVector::from(dates.clone());
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {

View File

@@ -53,7 +53,7 @@ impl Function for DateFormatFunction {
)
}
fn eval(&self, func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -202,7 +202,7 @@ mod tests {
let time_vector = TimestampSecondVector::from(times.clone());
let interval_vector = StringVector::from_vec(formats);
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
@@ -243,7 +243,7 @@ mod tests {
let date_vector = DateVector::from(dates.clone());
let interval_vector = StringVector::from_vec(formats);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {
@@ -284,7 +284,7 @@ mod tests {
let date_vector = DateTimeVector::from(dates.clone());
let interval_vector = StringVector::from_vec(formats);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {

View File

@@ -58,7 +58,7 @@ impl Function for DateSubFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -151,7 +151,7 @@ mod tests {
let time_vector = TimestampSecondVector::from(times.clone());
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
@@ -189,7 +189,7 @@ mod tests {
let date_vector = DateVector::from(dates.clone());
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {

View File

@@ -55,7 +55,7 @@ impl Function for IsNullFunction {
fn eval(
&self,
_func_ctx: FunctionContext,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
@@ -102,7 +102,7 @@ mod tests {
let values = vec![None, Some(3.0), None];
let args: Vec<VectorRef> = vec![Arc::new(Float32Vector::from(values))];
let vector = is_null.eval(FunctionContext::default(), &args).unwrap();
let vector = is_null.eval(&FunctionContext::default(), &args).unwrap();
let expect: VectorRef = Arc::new(BooleanVector::from_vec(vec![true, false, true]));
assert_eq!(expect, vector);
}

View File

@@ -118,7 +118,7 @@ impl Function for GeohashFunction {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
@@ -218,7 +218,7 @@ impl Function for GeohashNeighboursFunction {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {

View File

@@ -119,7 +119,7 @@ impl Function for H3LatLngToCell {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 3);
let lat_vec = &columns[0];
@@ -191,7 +191,7 @@ impl Function for H3LatLngToCellString {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 3);
let lat_vec = &columns[0];
@@ -247,7 +247,7 @@ impl Function for H3CellToString {
signature_of_cell()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let cell_vec = &columns[0];
@@ -285,7 +285,7 @@ impl Function for H3StringToCell {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let string_vec = &columns[0];
@@ -337,7 +337,7 @@ impl Function for H3CellCenterLatLng {
signature_of_cell()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let cell_vec = &columns[0];
@@ -382,7 +382,7 @@ impl Function for H3CellResolution {
signature_of_cell()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let cell_vec = &columns[0];
@@ -418,7 +418,7 @@ impl Function for H3CellBase {
signature_of_cell()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let cell_vec = &columns[0];
@@ -454,7 +454,7 @@ impl Function for H3CellIsPentagon {
signature_of_cell()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let cell_vec = &columns[0];
@@ -490,7 +490,7 @@ impl Function for H3CellCenterChild {
signature_of_cell_and_resolution()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];
@@ -530,7 +530,7 @@ impl Function for H3CellParent {
signature_of_cell_and_resolution()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];
@@ -570,7 +570,7 @@ impl Function for H3CellToChildren {
signature_of_cell_and_resolution()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];
@@ -619,7 +619,7 @@ impl Function for H3CellToChildrenSize {
signature_of_cell_and_resolution()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];
@@ -656,7 +656,7 @@ impl Function for H3CellToChildPos {
signature_of_cell_and_resolution()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];
@@ -706,7 +706,7 @@ impl Function for H3ChildPosToCell {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 3);
let pos_vec = &columns[0];
@@ -747,7 +747,7 @@ impl Function for H3GridDisk {
signature_of_cell_and_distance()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];
@@ -800,7 +800,7 @@ impl Function for H3GridDiskDistances {
signature_of_cell_and_distance()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];
@@ -850,7 +850,7 @@ impl Function for H3GridDistance {
signature_of_double_cells()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_this_vec = &columns[0];
@@ -906,7 +906,7 @@ impl Function for H3GridPathCells {
signature_of_double_cells()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_this_vec = &columns[0];
@@ -988,7 +988,7 @@ impl Function for H3CellContains {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cells_vec = &columns[0];
@@ -1042,7 +1042,7 @@ impl Function for H3CellDistanceSphereKm {
signature_of_double_cells()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_this_vec = &columns[0];
@@ -1097,7 +1097,7 @@ impl Function for H3CellDistanceEuclideanDegree {
signature_of_double_cells()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_this_vec = &columns[0];

View File

@@ -54,7 +54,7 @@ impl Function for STDistance {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let wkt_this_vec = &columns[0];
@@ -108,7 +108,7 @@ impl Function for STDistanceSphere {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let wkt_this_vec = &columns[0];
@@ -169,7 +169,7 @@ impl Function for STArea {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let wkt_vec = &columns[0];

View File

@@ -51,7 +51,7 @@ impl Function for STContains {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let wkt_this_vec = &columns[0];
@@ -105,7 +105,7 @@ impl Function for STWithin {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let wkt_this_vec = &columns[0];
@@ -159,7 +159,7 @@ impl Function for STIntersects {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let wkt_this_vec = &columns[0];

View File

@@ -84,7 +84,7 @@ impl Function for S2LatLngToCell {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let lat_vec = &columns[0];
@@ -138,7 +138,7 @@ impl Function for S2CellLevel {
signature_of_cell()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let cell_vec = &columns[0];
@@ -174,7 +174,7 @@ impl Function for S2CellToToken {
signature_of_cell()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 1);
let cell_vec = &columns[0];
@@ -210,7 +210,7 @@ impl Function for S2CellParent {
signature_of_cell_and_level()
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let cell_vec = &columns[0];

View File

@@ -63,7 +63,7 @@ impl Function for LatLngToPointWkt {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure_columns_n!(columns, 2);
let lat_vec = &columns[0];

View File

@@ -71,7 +71,7 @@ impl Function for HllCalcFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
if columns.len() != 1 {
return InvalidFuncArgsSnafu {
err_msg: format!("hll_count expects 1 argument, got {}", columns.len()),
@@ -142,7 +142,7 @@ mod tests {
let serialized_bytes = bincode::serialize(&hll).unwrap();
let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![Some(serialized_bytes)]))];
let result = function.eval(FunctionContext::default(), &args).unwrap();
let result = function.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(result.len(), 1);
// Test cardinality estimate
@@ -159,7 +159,7 @@ mod tests {
// Test with invalid number of arguments
let args: Vec<VectorRef> = vec![];
let result = function.eval(FunctionContext::default(), &args);
let result = function.eval(&FunctionContext::default(), &args);
assert!(result.is_err());
assert!(result
.unwrap_err()
@@ -168,7 +168,7 @@ mod tests {
// Test with invalid binary data
let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]))]; // Invalid binary data
let result = function.eval(FunctionContext::default(), &args).unwrap();
let result = function.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(result.len(), 1);
assert!(matches!(result.get(0), datatypes::value::Value::Null));
}

View File

@@ -72,7 +72,7 @@ macro_rules! json_get {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -175,7 +175,7 @@ impl Function for JsonGetString {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -282,7 +282,7 @@ mod tests {
let path_vector = StringVector::from_vec(paths);
let args: Vec<VectorRef> = vec![Arc::new(json_vector), Arc::new(path_vector)];
let vector = json_get_int
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
assert_eq!(3, vector.len());
@@ -335,7 +335,7 @@ mod tests {
let path_vector = StringVector::from_vec(paths);
let args: Vec<VectorRef> = vec![Arc::new(json_vector), Arc::new(path_vector)];
let vector = json_get_float
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
assert_eq!(3, vector.len());
@@ -388,7 +388,7 @@ mod tests {
let path_vector = StringVector::from_vec(paths);
let args: Vec<VectorRef> = vec![Arc::new(json_vector), Arc::new(path_vector)];
let vector = json_get_bool
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
assert_eq!(3, vector.len());
@@ -441,7 +441,7 @@ mod tests {
let path_vector = StringVector::from_vec(paths);
let args: Vec<VectorRef> = vec![Arc::new(json_vector), Arc::new(path_vector)];
let vector = json_get_string
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
assert_eq!(3, vector.len());

View File

@@ -45,7 +45,7 @@ macro_rules! json_is {
Signature::exact(vec![ConcreteDataType::json_datatype()], Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
@@ -202,7 +202,7 @@ mod tests {
let args: Vec<VectorRef> = vec![Arc::new(json_vector)];
for (func, expected_result) in json_is_functions.iter().zip(expected_results.iter()) {
let vector = func.eval(FunctionContext::default(), &args).unwrap();
let vector = func.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(vector.len(), json_strings.len());
for (i, expected) in expected_result.iter().enumerate() {

View File

@@ -64,7 +64,7 @@ impl Function for JsonPathExistsFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -204,7 +204,7 @@ mod tests {
let path_vector = StringVector::from_vec(paths);
let args: Vec<VectorRef> = vec![Arc::new(json_vector), Arc::new(path_vector)];
let vector = json_path_exists
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
// Test for non-nulls.
@@ -222,7 +222,7 @@ mod tests {
let illegal_path = StringVector::from_vec(vec!["$..a"]);
let args: Vec<VectorRef> = vec![Arc::new(json), Arc::new(illegal_path)];
let err = json_path_exists.eval(FunctionContext::default(), &args);
let err = json_path_exists.eval(&FunctionContext::default(), &args);
assert!(err.is_err());
// Test for nulls.
@@ -235,11 +235,11 @@ mod tests {
let args: Vec<VectorRef> = vec![Arc::new(null_json), Arc::new(path)];
let result1 = json_path_exists
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
let args: Vec<VectorRef> = vec![Arc::new(json), Arc::new(null_path)];
let result2 = json_path_exists
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
assert_eq!(result1.len(), 1);

View File

@@ -50,7 +50,7 @@ impl Function for JsonPathMatchFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -180,7 +180,7 @@ mod tests {
let path_vector = StringVector::from(paths);
let args: Vec<VectorRef> = vec![Arc::new(json_vector), Arc::new(path_vector)];
let vector = json_path_match
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
assert_eq!(7, vector.len());

View File

@@ -47,7 +47,7 @@ impl Function for JsonToStringFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
@@ -154,7 +154,7 @@ mod tests {
let json_vector = BinaryVector::from_vec(jsonbs);
let args: Vec<VectorRef> = vec![Arc::new(json_vector)];
let vector = json_to_string
.eval(FunctionContext::default(), &args)
.eval(&FunctionContext::default(), &args)
.unwrap();
assert_eq!(3, vector.len());
@@ -168,7 +168,7 @@ mod tests {
let invalid_jsonb = vec![b"invalid json"];
let invalid_json_vector = BinaryVector::from_vec(invalid_jsonb);
let args: Vec<VectorRef> = vec![Arc::new(invalid_json_vector)];
let vector = json_to_string.eval(FunctionContext::default(), &args);
let vector = json_to_string.eval(&FunctionContext::default(), &args);
assert!(vector.is_err());
}
}

View File

@@ -47,7 +47,7 @@ impl Function for ParseJsonFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
@@ -152,7 +152,7 @@ mod tests {
let json_string_vector = StringVector::from_vec(json_strings.to_vec());
let args: Vec<VectorRef> = vec![Arc::new(json_string_vector)];
let vector = parse_json.eval(FunctionContext::default(), &args).unwrap();
let vector = parse_json.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(3, vector.len());
for (i, gt) in jsonbs.iter().enumerate() {

View File

@@ -72,7 +72,7 @@ impl Function for MatchesFunction {
}
// TODO: read case-sensitive config
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -82,6 +82,12 @@ impl Function for MatchesFunction {
),
}
);
let data_column = &columns[0];
if data_column.is_empty() {
return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
}
let pattern_vector = &columns[1]
.cast(&ConcreteDataType::string_datatype())
.context(InvalidInputTypeSnafu {
@@ -89,12 +95,12 @@ impl Function for MatchesFunction {
})?;
// Safety: both length and type are checked before
let pattern = pattern_vector.get(0).as_string().unwrap();
self.eval(columns[0].clone(), pattern)
self.eval(data_column, pattern)
}
}
impl MatchesFunction {
fn eval(&self, data: VectorRef, pattern: String) -> Result<VectorRef> {
fn eval(&self, data: &VectorRef, pattern: String) -> Result<VectorRef> {
let col_name = "data";
let parser_context = ParserContext::default();
let raw_ast = parser_context.parse_pattern(&pattern)?;
@@ -1309,7 +1315,7 @@ mod test {
"The quick brown fox jumps over dog",
"The quick brown fox jumps over the dog",
];
let input_vector = Arc::new(StringVector::from(input_data));
let input_vector: VectorRef = Arc::new(StringVector::from(input_data));
let cases = [
// basic cases
("quick", vec![true, false, true, true, true, true, true]),
@@ -1400,7 +1406,7 @@ mod test {
let f = MatchesFunction;
for (pattern, expected) in cases {
let actual: VectorRef = f.eval(input_vector.clone(), pattern.to_string()).unwrap();
let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap();
let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _;
assert_eq!(expected, actual, "{pattern}");
}

View File

@@ -80,7 +80,7 @@ impl Function for RangeFunction {
Signature::variadic_any(Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
Err(DataFusionError::Internal(
"range_fn just a empty function used in range select, It should not be eval!".into(),
))

View File

@@ -27,7 +27,7 @@ use datatypes::vectors::PrimitiveVector;
use datatypes::with_match_primitive_type_id;
use snafu::{ensure, OptionExt};
use crate::function::Function;
use crate::function::{Function, FunctionContext};
#[derive(Clone, Debug, Default)]
pub struct ClampFunction;
@@ -49,11 +49,7 @@ impl Function for ClampFunction {
Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable)
}
fn eval(
&self,
_func_ctx: crate::function::FunctionContext,
columns: &[VectorRef],
) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
@@ -209,7 +205,7 @@ mod test {
Arc::new(Int64Vector::from_vec(vec![max])) as _,
];
let result = func
.eval(FunctionContext::default(), args.as_slice())
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
assert_eq!(expected, result);
@@ -253,7 +249,7 @@ mod test {
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
];
let result = func
.eval(FunctionContext::default(), args.as_slice())
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
assert_eq!(expected, result);
@@ -297,7 +293,7 @@ mod test {
Arc::new(Float64Vector::from_vec(vec![max])) as _,
];
let result = func
.eval(FunctionContext::default(), args.as_slice())
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
assert_eq!(expected, result);
@@ -317,7 +313,7 @@ mod test {
Arc::new(Int64Vector::from_vec(vec![max])) as _,
];
let result = func
.eval(FunctionContext::default(), args.as_slice())
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
assert_eq!(expected, result);
@@ -335,7 +331,7 @@ mod test {
Arc::new(Float64Vector::from_vec(vec![min])) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
];
let result = func.eval(FunctionContext::default(), args.as_slice());
let result = func.eval(&FunctionContext::default(), args.as_slice());
assert!(result.is_err());
}
@@ -351,7 +347,7 @@ mod test {
Arc::new(Int64Vector::from_vec(vec![min])) as _,
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
];
let result = func.eval(FunctionContext::default(), args.as_slice());
let result = func.eval(&FunctionContext::default(), args.as_slice());
assert!(result.is_err());
}
@@ -367,7 +363,7 @@ mod test {
Arc::new(Float64Vector::from_vec(vec![min, min])) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
];
let result = func.eval(FunctionContext::default(), args.as_slice());
let result = func.eval(&FunctionContext::default(), args.as_slice());
assert!(result.is_err());
}
@@ -381,7 +377,7 @@ mod test {
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
];
let result = func.eval(FunctionContext::default(), args.as_slice());
let result = func.eval(&FunctionContext::default(), args.as_slice());
assert!(result.is_err());
}
@@ -395,7 +391,7 @@ mod test {
Arc::new(StringVector::from_vec(vec!["bar"])) as _,
Arc::new(StringVector::from_vec(vec!["baz"])) as _,
];
let result = func.eval(FunctionContext::default(), args.as_slice());
let result = func.eval(&FunctionContext::default(), args.as_slice());
assert!(result.is_err());
}
}

View File

@@ -58,7 +58,7 @@ impl Function for ModuloFunction {
Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -126,7 +126,7 @@ mod tests {
Arc::new(Int32Vector::from_vec(nums.clone())),
Arc::new(Int32Vector::from_vec(divs.clone())),
];
let result = function.eval(FunctionContext::default(), &args).unwrap();
let result = function.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(result.len(), 4);
for i in 0..4 {
let p: i64 = (nums[i] % divs[i]) as i64;
@@ -158,7 +158,7 @@ mod tests {
Arc::new(UInt32Vector::from_vec(nums.clone())),
Arc::new(UInt32Vector::from_vec(divs.clone())),
];
let result = function.eval(FunctionContext::default(), &args).unwrap();
let result = function.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(result.len(), 4);
for i in 0..4 {
let p: u64 = (nums[i] % divs[i]) as u64;
@@ -190,7 +190,7 @@ mod tests {
Arc::new(Float64Vector::from_vec(nums.clone())),
Arc::new(Float64Vector::from_vec(divs.clone())),
];
let result = function.eval(FunctionContext::default(), &args).unwrap();
let result = function.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(result.len(), 4);
for i in 0..4 {
let p: f64 = nums[i] % divs[i];
@@ -209,7 +209,7 @@ mod tests {
Arc::new(Int32Vector::from_vec(nums.clone())),
Arc::new(Int32Vector::from_vec(divs.clone())),
];
let result = function.eval(FunctionContext::default(), &args);
let result = function.eval(&FunctionContext::default(), &args);
assert!(result.is_err());
let err_msg = result.unwrap_err().output_msg();
assert_eq!(
@@ -220,7 +220,7 @@ mod tests {
let nums = vec![27];
let args: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(nums.clone()))];
let result = function.eval(FunctionContext::default(), &args);
let result = function.eval(&FunctionContext::default(), &args);
assert!(result.is_err());
let err_msg = result.unwrap_err().output_msg();
assert!(
@@ -233,7 +233,7 @@ mod tests {
Arc::new(StringVector::from(nums.clone())),
Arc::new(StringVector::from(divs.clone())),
];
let result = function.eval(FunctionContext::default(), &args);
let result = function.eval(&FunctionContext::default(), &args);
assert!(result.is_err());
let err_msg = result.unwrap_err().output_msg();
assert!(err_msg.contains("Invalid arithmetic operation"));

View File

@@ -44,7 +44,7 @@ impl Function for PowFunction {
Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| {
let col = scalar_binary_op::<<$S as LogicalPrimitiveType>::Native, <$T as LogicalPrimitiveType>::Native, f64, _>(&columns[0], &columns[1], scalar_pow, &mut EvalContext::default())?;
@@ -109,7 +109,7 @@ mod tests {
Arc::new(Int8Vector::from_vec(bases.clone())),
];
let vector = pow.eval(FunctionContext::default(), &args).unwrap();
let vector = pow.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(3, vector.len());
for i in 0..3 {

View File

@@ -48,7 +48,7 @@ impl Function for RateFunction {
Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
let val = &columns[0].to_arrow_array();
let val_0 = val.slice(0, val.len() - 1);
let val_1 = val.slice(1, val.len() - 1);
@@ -100,7 +100,7 @@ mod tests {
Arc::new(Float32Vector::from_vec(values)),
Arc::new(Int64Vector::from_vec(ts)),
];
let vector = rate.eval(FunctionContext::default(), &args).unwrap();
let vector = rate.eval(&FunctionContext::default(), &args).unwrap();
let expect: VectorRef = Arc::new(Float64Vector::from_vec(vec![2.0, 3.0]));
assert_eq!(expect, vector);
}

View File

@@ -45,7 +45,7 @@ impl Function for TestAndFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
let col = scalar_binary_op::<bool, bool, bool, _>(
&columns[0],
&columns[1],

View File

@@ -97,7 +97,7 @@ impl Function for GreatestFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -191,7 +191,9 @@ mod tests {
])) as _,
];
let result = function.eval(FunctionContext::default(), &columns).unwrap();
let result = function
.eval(&FunctionContext::default(), &columns)
.unwrap();
let result = result.as_any().downcast_ref::<DateTimeVector>().unwrap();
assert_eq!(result.len(), 2);
assert_eq!(
@@ -222,7 +224,9 @@ mod tests {
Arc::new(DateVector::from_slice(vec![0, 1])) as _,
];
let result = function.eval(FunctionContext::default(), &columns).unwrap();
let result = function
.eval(&FunctionContext::default(), &columns)
.unwrap();
let result = result.as_any().downcast_ref::<DateVector>().unwrap();
assert_eq!(result.len(), 2);
assert_eq!(
@@ -253,7 +257,9 @@ mod tests {
Arc::new(DateTimeVector::from_slice(vec![0, 1])) as _,
];
let result = function.eval(FunctionContext::default(), &columns).unwrap();
let result = function
.eval(&FunctionContext::default(), &columns)
.unwrap();
let result = result.as_any().downcast_ref::<DateTimeVector>().unwrap();
assert_eq!(result.len(), 2);
assert_eq!(
@@ -282,7 +288,7 @@ mod tests {
Arc::new([<Timestamp $unit Vector>]::from_slice(vec![0, 1])) as _,
];
let result = function.eval(FunctionContext::default(), &columns).unwrap();
let result = function.eval(&FunctionContext::default(), &columns).unwrap();
let result = result.as_any().downcast_ref::<[<Timestamp $unit Vector>]>().unwrap();
assert_eq!(result.len(), 2);
assert_eq!(

View File

@@ -92,7 +92,7 @@ impl Function for ToUnixtimeFunction {
)
}
fn eval(&self, func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
@@ -108,7 +108,7 @@ impl Function for ToUnixtimeFunction {
match columns[0].data_type() {
ConcreteDataType::String(_) => Ok(Arc::new(Int64Vector::from(
(0..vector.len())
.map(|i| convert_to_seconds(&vector.get(i).to_string(), &func_ctx))
.map(|i| convert_to_seconds(&vector.get(i).to_string(), ctx))
.collect::<Vec<_>>(),
))),
ConcreteDataType::Int64(_) | ConcreteDataType::Int32(_) => {
@@ -187,7 +187,7 @@ mod tests {
];
let results = [Some(1677652502), None, Some(1656633600), None];
let args: Vec<VectorRef> = vec![Arc::new(StringVector::from(times.clone()))];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
@@ -211,7 +211,7 @@ mod tests {
let times = vec![Some(3_i64), None, Some(5_i64), None];
let results = [Some(3), None, Some(5), None];
let args: Vec<VectorRef> = vec![Arc::new(Int64Vector::from(times.clone()))];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
@@ -236,7 +236,7 @@ mod tests {
let results = [Some(10627200), None, Some(3628800), None];
let date_vector = DateVector::from(times.clone());
let args: Vec<VectorRef> = vec![Arc::new(date_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
@@ -261,7 +261,7 @@ mod tests {
let results = [Some(123), None, Some(42), None];
let date_vector = DateTimeVector::from(times.clone());
let args: Vec<VectorRef> = vec![Arc::new(date_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
@@ -286,7 +286,7 @@ mod tests {
let results = [Some(123), None, Some(42), None];
let ts_vector = TimestampSecondVector::from(times.clone());
let args: Vec<VectorRef> = vec![Arc::new(ts_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
@@ -306,7 +306,7 @@ mod tests {
let results = [Some(123), None, Some(42), None];
let ts_vector = TimestampMillisecondVector::from(times.clone());
let args: Vec<VectorRef> = vec![Arc::new(ts_vector)];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);

View File

@@ -75,7 +75,7 @@ impl Function for UddSketchCalcFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
if columns.len() != 2 {
return InvalidFuncArgsSnafu {
err_msg: format!("uddsketch_calc expects 2 arguments, got {}", columns.len()),
@@ -169,7 +169,7 @@ mod tests {
Arc::new(BinaryVector::from(vec![Some(serialized.clone()); 3])),
];
let result = function.eval(FunctionContext::default(), &args).unwrap();
let result = function.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(result.len(), 3);
// Test median (p50)
@@ -192,7 +192,7 @@ mod tests {
// Test with invalid number of arguments
let args: Vec<VectorRef> = vec![Arc::new(Float64Vector::from_vec(vec![0.95]))];
let result = function.eval(FunctionContext::default(), &args);
let result = function.eval(&FunctionContext::default(), &args);
assert!(result.is_err());
assert!(result
.unwrap_err()
@@ -204,7 +204,7 @@ mod tests {
Arc::new(Float64Vector::from_vec(vec![0.95])),
Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])])), // Invalid binary data
];
let result = function.eval(FunctionContext::default(), &args).unwrap();
let result = function.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(result.len(), 1);
assert!(matches!(result.get(0), datatypes::value::Value::Null));
}

View File

@@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use common_query::error::FromScalarValueSnafu;
use common_query::prelude::{
ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf,
};
use datatypes::error::Error as DataTypeError;
use common_query::prelude::ColumnarValue;
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
use datafusion_expr::ScalarUDF;
use datatypes::data_type::DataType;
use datatypes::prelude::*;
use datatypes::vectors::Helper;
use session::context::QueryContextRef;
@@ -27,58 +29,92 @@ use snafu::ResultExt;
use crate::function::{FunctionContext, FunctionRef};
use crate::state::FunctionState;
/// Adapter that exposes a GreptimeDB [`Function`] to DataFusion as a scalar
/// UDF by implementing [`ScalarUDFImpl`] over it.
struct ScalarUdf {
    /// The wrapped GreptimeDB function implementation.
    function: FunctionRef,
    /// DataFusion-side signature, converted from the function's own
    /// signature once at construction time (see `create_udf`).
    signature: datafusion_expr::Signature,
    /// Query context and shared state captured at creation; passed by
    /// reference to every `eval` invocation.
    context: FunctionContext,
}
impl Debug for ScalarUdf {
    /// Manual `Debug`: renders only the wrapped function's name and the
    /// signature; the captured `context` is deliberately omitted from the
    /// output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScalarUdf")
            .field("function", &self.function.name())
            .field("signature", &self.signature)
            .finish()
    }
}
impl ScalarUDFImpl for ScalarUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// UDF name as registered with DataFusion; delegates to the wrapped function.
    fn name(&self) -> &str {
        self.function.name()
    }

    fn signature(&self) -> &datafusion_expr::Signature {
        &self.signature
    }

    /// Computes the return type by converting the Arrow argument types into
    /// `ConcreteDataType`, asking the wrapped function, and converting the
    /// result back to an Arrow type.
    fn return_type(
        &self,
        arg_types: &[datatypes::arrow::datatypes::DataType],
    ) -> datafusion_common::Result<datatypes::arrow::datatypes::DataType> {
        let arg_types = arg_types
            .iter()
            .map(ConcreteDataType::from_arrow_type)
            .collect::<Vec<_>>();
        let t = self.function.return_type(&arg_types)?;
        Ok(t.as_arrow_type())
    }

    /// Evaluates the UDF: converts each DataFusion argument into a vector
    /// (scalars are expanded to `args.number_rows` rows so all columns have
    /// equal length), runs the wrapped function's `eval` with the captured
    /// context, and wraps the result back into a DataFusion columnar value.
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
        let columns = args
            .args
            .iter()
            .map(|x| {
                ColumnarValue::try_from(x).and_then(|y| match y {
                    ColumnarValue::Vector(z) => Ok(z),
                    // Broadcast a scalar argument to a constant vector of
                    // `number_rows` elements.
                    ColumnarValue::Scalar(z) => Helper::try_from_scalar_value(z, args.number_rows)
                        .context(FromScalarValueSnafu),
                })
            })
            .collect::<common_query::error::Result<Vec<_>>>()?;
        let v = self
            .function
            .eval(&self.context, &columns)
            .map(ColumnarValue::Vector)?;
        Ok(v.into())
    }
}
/// Create a ScalarUdf from function, query context and state.
pub fn create_udf(
func: FunctionRef,
query_ctx: QueryContextRef,
state: Arc<FunctionState>,
) -> ScalarUdf {
let func_cloned = func.clone();
let return_type: ReturnTypeFunction = Arc::new(move |input_types: &[ConcreteDataType]| {
Ok(Arc::new(func_cloned.return_type(input_types)?))
});
let func_cloned = func.clone();
let fun: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| {
let func_ctx = FunctionContext {
query_ctx: query_ctx.clone(),
state: state.clone(),
};
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Vector(v) => Some(v.len()),
});
let rows = len.unwrap_or(1);
let args: Result<Vec<_>, DataTypeError> = args
.iter()
.map(|arg| match arg {
ColumnarValue::Scalar(v) => Helper::try_from_scalar_value(v.clone(), rows),
ColumnarValue::Vector(v) => Ok(v.clone()),
})
.collect();
let result = func_cloned.eval(func_ctx, &args.context(FromScalarValueSnafu)?);
let udf_result = result.map(ColumnarValue::Vector)?;
Ok(udf_result)
});
ScalarUdf::new(func.name(), &func.signature(), &return_type, &fun)
) -> ScalarUDF {
let signature = func.signature().into();
let udf = ScalarUdf {
function: func,
signature,
context: FunctionContext { query_ctx, state },
};
ScalarUDF::new_from_impl(udf)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use common_query::prelude::{ColumnarValue, ScalarValue};
use common_query::prelude::ScalarValue;
use datafusion::arrow::array::BooleanArray;
use datatypes::data_type::ConcreteDataType;
use datatypes::prelude::{ScalarVector, Vector, VectorRef};
use datatypes::value::Value;
use datatypes::prelude::VectorRef;
use datatypes::vectors::{BooleanVector, ConstantVector};
use session::context::QueryContextBuilder;
@@ -99,7 +135,7 @@ mod tests {
Arc::new(BooleanVector::from(vec![true, false, true])),
];
let vector = f.eval(FunctionContext::default(), &args).unwrap();
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
assert_eq!(3, vector.len());
for i in 0..3 {
@@ -109,30 +145,36 @@ mod tests {
// create a udf and test it again
let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default()));
assert_eq!("test_and", udf.name);
assert_eq!(f.signature(), udf.signature);
assert_eq!("test_and", udf.name());
let expected_signature: datafusion_expr::Signature = f.signature().into();
assert_eq!(udf.signature(), &expected_signature);
assert_eq!(
Arc::new(ConcreteDataType::boolean_datatype()),
((udf.return_type)(&[])).unwrap()
ConcreteDataType::boolean_datatype(),
udf.return_type(&[])
.map(|x| ConcreteDataType::from_arrow_type(&x))
.unwrap()
);
let args = vec![
ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
ColumnarValue::Vector(Arc::new(BooleanVector::from(vec![
datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
true, false, false, true,
]))),
];
let vec = (udf.fun)(&args).unwrap();
match vec {
ColumnarValue::Vector(vec) => {
let vec = vec.as_any().downcast_ref::<BooleanVector>().unwrap();
assert_eq!(4, vec.len());
for i in 0..4 {
assert_eq!(i == 0 || i == 3, vec.get_data(i).unwrap(), "Failed at {i}",)
}
let args = ScalarFunctionArgs {
args: &args,
number_rows: 4,
return_type: &ConcreteDataType::boolean_datatype().as_arrow_type(),
};
match udf.invoke_with_args(args).unwrap() {
datafusion_expr::ColumnarValue::Array(x) => {
let x = x.as_any().downcast_ref::<BooleanArray>().unwrap();
assert_eq!(x.len(), 4);
assert_eq!(
x.iter().flatten().collect::<Vec<bool>>(),
vec![true, false, false, true]
);
}
_ => unreachable!(),
}

View File

@@ -45,7 +45,7 @@ impl Function for ParseVectorFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
@@ -101,7 +101,7 @@ mod tests {
None,
]));
let result = func.eval(FunctionContext::default(), &[input]).unwrap();
let result = func.eval(&FunctionContext::default(), &[input]).unwrap();
let result = result.as_ref();
assert_eq!(result.len(), 3);
@@ -136,7 +136,7 @@ mod tests {
Some("[7.0,8.0,9.0".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input]);
let result = func.eval(&FunctionContext::default(), &[input]);
assert!(result.is_err());
let input = Arc::new(StringVector::from(vec![
@@ -145,7 +145,7 @@ mod tests {
Some("7.0,8.0,9.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input]);
let result = func.eval(&FunctionContext::default(), &[input]);
assert!(result.is_err());
let input = Arc::new(StringVector::from(vec![
@@ -154,7 +154,7 @@ mod tests {
Some("[7.0,hello,9.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input]);
let result = func.eval(&FunctionContext::default(), &[input]);
assert!(result.is_err());
}
}

View File

@@ -46,7 +46,7 @@ impl Function for VectorToStringFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
@@ -129,7 +129,7 @@ mod tests {
builder.push_null();
let vector = builder.to_vector();
let result = func.eval(FunctionContext::default(), &[vector]).unwrap();
let result = func.eval(&FunctionContext::default(), &[vector]).unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result.get(0), Value::String("[1,2,3]".to_string().into()));

View File

@@ -60,7 +60,7 @@ macro_rules! define_distance_function {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -159,7 +159,7 @@ mod tests {
])) as VectorRef;
let result = func
.eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
assert!(!result.get(0).is_null());
@@ -168,7 +168,7 @@ mod tests {
assert!(result.get(3).is_null());
let result = func
.eval(FunctionContext::default(), &[vec2, vec1])
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
assert!(!result.get(0).is_null());
@@ -202,7 +202,7 @@ mod tests {
])) as VectorRef;
let result = func
.eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
assert!(!result.get(0).is_null());
@@ -211,7 +211,7 @@ mod tests {
assert!(result.get(3).is_null());
let result = func
.eval(FunctionContext::default(), &[vec2, vec1])
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
assert!(!result.get(0).is_null());
@@ -245,7 +245,7 @@ mod tests {
])) as VectorRef;
let result = func
.eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
assert!(!result.get(0).is_null());
@@ -254,7 +254,7 @@ mod tests {
assert!(result.get(3).is_null());
let result = func
.eval(FunctionContext::default(), &[vec2, vec1])
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
assert!(!result.get(0).is_null());
@@ -294,7 +294,7 @@ mod tests {
let result = func
.eval(
FunctionContext::default(),
&FunctionContext::default(),
&[const_str.clone(), vec1.clone()],
)
.unwrap();
@@ -306,7 +306,7 @@ mod tests {
let result = func
.eval(
FunctionContext::default(),
&FunctionContext::default(),
&[vec1.clone(), const_str.clone()],
)
.unwrap();
@@ -318,7 +318,7 @@ mod tests {
let result = func
.eval(
FunctionContext::default(),
&FunctionContext::default(),
&[const_str.clone(), vec2.clone()],
)
.unwrap();
@@ -330,7 +330,7 @@ mod tests {
let result = func
.eval(
FunctionContext::default(),
&FunctionContext::default(),
&[vec2.clone(), const_str.clone()],
)
.unwrap();
@@ -353,13 +353,13 @@ mod tests {
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
let result = func.eval(FunctionContext::default(), &[vec1, vec2]);
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
assert!(result.is_err());
let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
let vec2 =
Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
let result = func.eval(FunctionContext::default(), &[vec1, vec2]);
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
assert!(result.is_err());
}
}

View File

@@ -68,7 +68,7 @@ impl Function for ElemProductFunction {
fn eval(
&self,
_func_ctx: FunctionContext,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
@@ -131,7 +131,7 @@ mod tests {
None,
]));
let result = func.eval(FunctionContext::default(), &[input0]).unwrap();
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = result.as_ref();
assert_eq!(result.len(), 3);

View File

@@ -55,7 +55,7 @@ impl Function for ElemSumFunction {
fn eval(
&self,
_func_ctx: FunctionContext,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
@@ -118,7 +118,7 @@ mod tests {
None,
]));
let result = func.eval(FunctionContext::default(), &[input0]).unwrap();
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = result.as_ref();
assert_eq!(result.len(), 3);

View File

@@ -73,7 +73,7 @@ impl Function for ScalarAddFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -154,7 +154,7 @@ mod tests {
]));
let result = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap();
let result = result.as_ref();

View File

@@ -73,7 +73,7 @@ impl Function for ScalarMulFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -154,7 +154,7 @@ mod tests {
]));
let result = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap();
let result = result.as_ref();

View File

@@ -72,7 +72,7 @@ impl Function for VectorAddFunction {
fn eval(
&self,
_func_ctx: FunctionContext,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
@@ -166,7 +166,7 @@ mod tests {
]));
let result = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap();
let result = result.as_ref();
@@ -199,7 +199,7 @@ mod tests {
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input0, input1]);
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {

View File

@@ -67,7 +67,7 @@ impl Function for VectorDimFunction {
fn eval(
&self,
_func_ctx: FunctionContext,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
@@ -131,7 +131,7 @@ mod tests {
Some("[5.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input0]).unwrap();
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = result.as_ref();
assert_eq!(result.len(), 4);
@@ -157,7 +157,7 @@ mod tests {
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input0, input1]);
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {

View File

@@ -68,7 +68,7 @@ impl Function for VectorDivFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -155,7 +155,7 @@ mod tests {
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let err = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
@@ -186,7 +186,7 @@ mod tests {
]));
let result = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap();
let result = result.as_ref();
@@ -206,7 +206,7 @@ mod tests {
let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));
let result = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap();
let result = result.as_ref();

View File

@@ -68,7 +68,7 @@ impl Function for VectorMulFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
@@ -155,7 +155,7 @@ mod tests {
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let err = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
@@ -186,7 +186,7 @@ mod tests {
]));
let result = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap();
let result = result.as_ref();

View File

@@ -67,7 +67,7 @@ impl Function for VectorNormFunction {
fn eval(
&self,
_func_ctx: FunctionContext,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
@@ -143,7 +143,7 @@ mod tests {
None,
]));
let result = func.eval(FunctionContext::default(), &[input0]).unwrap();
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = result.as_ref();
assert_eq!(result.len(), 5);

View File

@@ -72,7 +72,7 @@ impl Function for VectorSubFunction {
fn eval(
&self,
_func_ctx: FunctionContext,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
@@ -166,7 +166,7 @@ mod tests {
]));
let result = func
.eval(FunctionContext::default(), &[input0, input1])
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap();
let result = result.as_ref();
@@ -199,7 +199,7 @@ mod tests {
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input0, input1]);
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {

View File

@@ -45,7 +45,7 @@ impl Function for BuildFunction {
Signature::nullary(Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let build_info = common_version::build_info().to_string();
let v = Arc::new(StringVector::from(vec![build_info]));
Ok(v)
@@ -67,7 +67,7 @@ mod tests {
);
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
let build_info = common_version::build_info().to_string();
let vector = build.eval(FunctionContext::default(), &[]).unwrap();
let vector = build.eval(&FunctionContext::default(), &[]).unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec![build_info]));
assert_eq!(expect, vector);
}

View File

@@ -47,7 +47,7 @@ impl Function for DatabaseFunction {
Signature::nullary(Volatility::Immutable)
}
fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let db = func_ctx.query_ctx.current_schema();
Ok(Arc::new(StringVector::from_slice(&[&db])) as _)
@@ -67,7 +67,7 @@ impl Function for CurrentSchemaFunction {
Signature::uniform(0, vec![], Volatility::Immutable)
}
fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let db = func_ctx.query_ctx.current_schema();
Ok(Arc::new(StringVector::from_slice(&[&db])) as _)
@@ -87,7 +87,7 @@ impl Function for SessionUserFunction {
Signature::uniform(0, vec![], Volatility::Immutable)
}
fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let user = func_ctx.query_ctx.current_user();
Ok(Arc::new(StringVector::from_slice(&[user.username()])) as _)
@@ -138,7 +138,7 @@ mod tests {
query_ctx,
..Default::default()
};
let vector = build.eval(func_ctx, &[]).unwrap();
let vector = build.eval(&func_ctx, &[]).unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_db"]));
assert_eq!(expect, vector);
}

View File

@@ -53,7 +53,7 @@ impl Function for PGGetUserByIdFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$T| {
let col = scalar_unary_op::<<$T as LogicalPrimitiveType>::Native, String, _>(&columns[0], pg_get_user_by_id, &mut EvalContext::default())?;
Ok(Arc::new(col))

View File

@@ -53,7 +53,7 @@ impl Function for PGTableIsVisibleFunction {
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$T| {
let col = scalar_unary_op::<<$T as LogicalPrimitiveType>::Native, bool, _>(&columns[0], pg_table_is_visible, &mut EvalContext::default())?;
Ok(Arc::new(col))

View File

@@ -44,7 +44,7 @@ impl Function for PGVersionFunction {
Signature::exact(vec![], Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let result = StringVector::from(vec![format!(
"PostgreSQL 16.3 GreptimeDB {}",
env!("CARGO_PKG_VERSION")

View File

@@ -41,7 +41,7 @@ impl Function for TimezoneFunction {
Signature::nullary(Volatility::Immutable)
}
fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let tz = func_ctx.query_ctx.timezone().to_string();
Ok(Arc::new(StringVector::from_slice(&[&tz])) as _)
@@ -77,7 +77,7 @@ mod tests {
query_ctx,
..Default::default()
};
let vector = build.eval(func_ctx, &[]).unwrap();
let vector = build.eval(&func_ctx, &[]).unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["UTC"]));
assert_eq!(expect, vector);
}

View File

@@ -45,7 +45,7 @@ impl Function for VersionFunction {
Signature::nullary(Volatility::Immutable)
}
fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let version = match func_ctx.query_ctx.channel() {
Channel::Mysql => {
format!(

View File

@@ -30,14 +30,6 @@ use statrs::StatsError;
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
#[snafu(display("Failed to execute function"))]
ExecuteFunction {
#[snafu(source)]
error: DataFusionError,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Unsupported input datatypes {:?} in function {}", datatypes, function))]
UnsupportedInputDataType {
function: String,
@@ -264,9 +256,7 @@ impl ErrorExt for Error {
| Error::ArrowCompute { .. }
| Error::FlownodeNotFound { .. } => StatusCode::EngineExecuteQuery,
Error::ExecuteFunction { error, .. } | Error::GeneralDataFusion { error, .. } => {
datafusion_status_code::<Self>(error, None)
}
Error::GeneralDataFusion { error, .. } => datafusion_status_code::<Self>(error, None),
Error::InvalidInputType { source, .. }
| Error::IntoVector { source, .. }

View File

@@ -17,23 +17,9 @@ use std::sync::Arc;
use datafusion_expr::ReturnTypeFunction as DfReturnTypeFunction;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::prelude::{ConcreteDataType, DataType};
use datatypes::vectors::VectorRef;
use snafu::ResultExt;
use crate::error::{ExecuteFunctionSnafu, Result};
use crate::error::Result;
use crate::logical_plan::Accumulator;
use crate::prelude::{ColumnarValue, ScalarValue};
/// Scalar function
///
/// The Fn param is the wrapped function but be aware that the function will
/// be passed with the slice / vec of columnar values (either scalar or array)
/// with the exception of zero param function, where a singular element vec
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
pub type ScalarFunctionImplementation =
Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
/// A function's return type
pub type ReturnTypeFunction =
@@ -51,48 +37,6 @@ pub type AccumulatorCreatorFunction =
pub type StateTypeFunction =
Arc<dyn Fn(&ConcreteDataType) -> Result<Arc<Vec<ConcreteDataType>>> + Send + Sync>;
/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// and vice-versa after evaluation.
pub fn make_scalar_function<F>(inner: F) -> ScalarFunctionImplementation
where
F: Fn(&[VectorRef]) -> Result<VectorRef> + Sync + Send + 'static,
{
Arc::new(move |args: &[ColumnarValue]| {
// first, identify if any of the arguments is an vector. If yes, store its `len`,
// as any scalar will need to be converted to an vector of len `len`.
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Vector(v) => Some(v.len()),
});
// to array
// TODO(dennis): we create new vectors from Scalar on each call,
// should be optimized in the future.
let args: Result<Vec<_>> = if let Some(len) = len {
args.iter()
.map(|arg| arg.clone().try_into_vector(len))
.collect()
} else {
args.iter()
.map(|arg| arg.clone().try_into_vector(1))
.collect()
};
let result = (inner)(&args?);
// maybe back to scalar
if len.is_some() {
result.map(ColumnarValue::Vector)
} else {
Ok(ScalarValue::try_from_array(&result?.to_arrow_array(), 0)
.map(ColumnarValue::Scalar)
.context(ExecuteFunctionSnafu)?)
}
})
}
pub fn to_df_return_type(func: ReturnTypeFunction) -> DfReturnTypeFunction {
let df_func = move |data_types: &[ArrowDataType]| {
// DataFusion DataType -> ConcreteDataType
@@ -111,60 +55,3 @@ pub fn to_df_return_type(func: ReturnTypeFunction) -> DfReturnTypeFunction {
};
Arc::new(df_func)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datatypes::prelude::{ScalarVector, Vector};
use datatypes::vectors::BooleanVector;
use super::*;
#[test]
fn test_make_scalar_function() {
let and_fun = |args: &[VectorRef]| -> Result<VectorRef> {
let left = &args[0]
.as_any()
.downcast_ref::<BooleanVector>()
.expect("cast failed");
let right = &args[1]
.as_any()
.downcast_ref::<BooleanVector>()
.expect("cast failed");
let result = left
.iter_data()
.zip(right.iter_data())
.map(|(left, right)| match (left, right) {
(Some(left), Some(right)) => Some(left && right),
_ => None,
})
.collect::<BooleanVector>();
Ok(Arc::new(result) as VectorRef)
};
let and_fun = make_scalar_function(and_fun);
let args = vec![
ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
ColumnarValue::Vector(Arc::new(BooleanVector::from(vec![
true, false, false, true,
]))),
];
let vec = (and_fun)(&args).unwrap();
match vec {
ColumnarValue::Vector(vec) => {
let vec = vec.as_any().downcast_ref::<BooleanVector>().unwrap();
assert_eq!(4, vec.len());
for i in 0..4 {
assert_eq!(i == 0 || i == 3, vec.get_data(i).unwrap(), "Failed at {i}")
}
}
_ => unreachable!(),
}
}
}

View File

@@ -15,7 +15,6 @@
pub mod accumulator;
mod expr;
mod udaf;
mod udf;
use std::sync::Arc;
@@ -24,38 +23,14 @@ use datafusion::error::Result as DatafusionResult;
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
use datafusion_common::Column;
use datafusion_expr::col;
use datatypes::prelude::ConcreteDataType;
pub use expr::{build_filter_from_timestamp, build_same_type_ts_filter};
pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef};
pub use self::udaf::AggregateFunction;
pub use self::udf::ScalarUdf;
use crate::error::Result;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
use crate::logical_plan::accumulator::*;
use crate::signature::{Signature, Volatility};
/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
/// * the UDF has a fixed return type
/// * the UDF has a fixed signature (e.g. [f64, f64])
pub fn create_udf(
name: &str,
input_types: Vec<ConcreteDataType>,
return_type: Arc<ConcreteDataType>,
volatility: Volatility,
fun: ScalarFunctionImplementation,
) -> ScalarUdf {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
ScalarUdf::new(
name,
&Signature::exact(input_types, volatility),
&return_type,
&fun,
)
}
pub fn create_aggregate_function(
name: String,
args_count: u8,
@@ -127,102 +102,17 @@ pub type SubstraitPlanDecoderRef = Arc<dyn SubstraitPlanDecoder + Send + Sync>;
mod tests {
use std::sync::Arc;
use datafusion_common::DFSchema;
use datafusion_expr::builder::LogicalTableSource;
use datafusion_expr::{
lit, ColumnarValue as DfColumnarValue, ScalarUDF as DfScalarUDF,
TypeSignature as DfTypeSignature,
};
use datatypes::arrow::array::BooleanArray;
use datafusion_expr::lit;
use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datatypes::prelude::*;
use datatypes::vectors::{BooleanVector, VectorRef};
use datatypes::vectors::VectorRef;
use super::*;
use crate::error::Result;
use crate::function::{make_scalar_function, AccumulatorCreatorFunction};
use crate::prelude::ScalarValue;
use crate::function::AccumulatorCreatorFunction;
use crate::signature::TypeSignature;
#[test]
fn test_create_udf() {
let and_fun = |args: &[VectorRef]| -> Result<VectorRef> {
let left = &args[0]
.as_any()
.downcast_ref::<BooleanVector>()
.expect("cast failed");
let right = &args[1]
.as_any()
.downcast_ref::<BooleanVector>()
.expect("cast failed");
let result = left
.iter_data()
.zip(right.iter_data())
.map(|(left, right)| match (left, right) {
(Some(left), Some(right)) => Some(left && right),
_ => None,
})
.collect::<BooleanVector>();
Ok(Arc::new(result) as VectorRef)
};
let and_fun = make_scalar_function(and_fun);
let input_types = vec![
ConcreteDataType::boolean_datatype(),
ConcreteDataType::boolean_datatype(),
];
let return_type = Arc::new(ConcreteDataType::boolean_datatype());
let udf = create_udf(
"and",
input_types.clone(),
return_type.clone(),
Volatility::Immutable,
and_fun.clone(),
);
assert_eq!("and", udf.name);
assert!(
matches!(&udf.signature.type_signature, TypeSignature::Exact(ts) if ts.clone() == input_types)
);
assert_eq!(return_type, (udf.return_type)(&[]).unwrap());
// test into_df_udf
let df_udf: DfScalarUDF = udf.into();
assert_eq!("and", df_udf.name());
let types = vec![DataType::Boolean, DataType::Boolean];
assert!(
matches!(&df_udf.signature().type_signature, DfTypeSignature::Exact(ts) if ts.clone() == types)
);
assert_eq!(
DataType::Boolean,
df_udf
.return_type_from_exprs(&[], &DFSchema::empty(), &[])
.unwrap()
);
let args = vec![
DfColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
DfColumnarValue::Array(Arc::new(BooleanArray::from(vec![true, false, false, true]))),
];
// call the function
let result = df_udf.invoke_batch(&args, 4).unwrap();
match result {
DfColumnarValue::Array(arr) => {
let arr = arr.as_any().downcast_ref::<BooleanArray>().unwrap();
for i in 0..4 {
assert_eq!(i == 0 || i == 3, arr.value(i));
}
}
_ => unreachable!(),
}
}
#[derive(Debug)]
struct DummyAccumulator;

View File

@@ -1,134 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Udf module contains foundational types that are used to represent UDFs.
//! It's modified from datafusion.
use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use datafusion_expr::{
ColumnarValue as DfColumnarValue,
ScalarFunctionImplementation as DfScalarFunctionImplementation, ScalarUDF as DfScalarUDF,
ScalarUDFImpl,
};
use datatypes::arrow::datatypes::DataType;
use crate::error::Result;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
use crate::prelude::to_df_return_type;
use crate::signature::Signature;
/// Logical representation of a UDF.
#[derive(Clone)]
pub struct ScalarUdf {
/// name
pub name: String,
/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
/// actual implementation
pub fun: ScalarFunctionImplementation,
}
impl Debug for ScalarUdf {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("ScalarUdf")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}
impl ScalarUdf {
/// Create a new ScalarUdf
pub fn new(
name: &str,
signature: &Signature,
return_type: &ReturnTypeFunction,
fun: &ScalarFunctionImplementation,
) -> Self {
Self {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
fun: fun.clone(),
}
}
}
#[derive(Clone)]
struct DfUdfAdapter {
name: String,
signature: datafusion_expr::Signature,
return_type: datafusion_expr::ReturnTypeFunction,
fun: DfScalarFunctionImplementation,
}
impl Debug for DfUdfAdapter {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("DfUdfAdapter")
.field("name", &self.name)
.field("signature", &self.signature)
.finish()
}
}
impl ScalarUDFImpl for DfUdfAdapter {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &datafusion_expr::Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
(self.return_type)(arg_types).map(|ty| ty.as_ref().clone())
}
fn invoke(&self, args: &[DfColumnarValue]) -> datafusion_common::Result<DfColumnarValue> {
(self.fun)(args)
}
fn invoke_no_args(&self, number_rows: usize) -> datafusion_common::Result<DfColumnarValue> {
Ok((self.fun)(&[])?.into_array(number_rows)?.into())
}
}
impl From<ScalarUdf> for DfScalarUDF {
fn from(udf: ScalarUdf) -> Self {
DfScalarUDF::new_from_impl(DfUdfAdapter {
name: udf.name,
signature: udf.signature.into(),
return_type: to_df_return_type(udf.return_type),
fun: to_df_scalar_func(udf.fun),
})
}
}
fn to_df_scalar_func(fun: ScalarFunctionImplementation) -> DfScalarFunctionImplementation {
Arc::new(move |args: &[DfColumnarValue]| {
let args: Result<Vec<_>> = args.iter().map(TryFrom::try_from).collect();
let result = fun(&args?);
result.map(From::from).map_err(|e| e.into())
})
}

View File

@@ -16,7 +16,7 @@ pub use datafusion_common::ScalarValue;
pub use crate::columnar_value::ColumnarValue;
pub use crate::function::*;
pub use crate::logical_plan::{create_udf, AggregateFunction, ScalarUdf};
pub use crate::logical_plan::AggregateFunction;
pub use crate::signature::{Signature, TypeSignature, Volatility};
/// Default timestamp column name for Prometheus metrics.

View File

@@ -21,7 +21,6 @@ use async_trait::async_trait;
use common_error::ext::BoxedError;
use common_function::function::FunctionRef;
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_query::prelude::ScalarUdf;
use common_query::Output;
use common_runtime::runtime::{BuilderBuild, RuntimeTrait};
use common_runtime::Runtime;
@@ -77,8 +76,6 @@ impl QueryEngine for MockQueryEngine {
unimplemented!()
}
fn register_udf(&self, _udf: ScalarUdf) {}
fn register_aggregate_function(&self, _func: AggregateFunctionMetaRef) {}
fn register_function(&self, _func: FunctionRef) {}

View File

@@ -17,6 +17,7 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use common_error::ext::BoxedError;
use common_function::function::FunctionContext;
use datafusion_substrait::extensions::Extensions;
use datatypes::data_type::ConcreteDataType as CDT;
use query::QueryEngine;
@@ -146,7 +147,7 @@ impl common_function::function::Function for TumbleFunction {
fn eval(
&self,
_func_ctx: common_function::function::FunctionContext,
_func_ctx: &FunctionContext,
_columns: &[datatypes::prelude::VectorRef],
) -> common_query::error::Result<datatypes::prelude::VectorRef> {
UnexpectedSnafu {

View File

@@ -262,8 +262,11 @@ fn create_df_context(
];
for udf in udfs {
df_context
.register_udf(create_udf(udf, ctx.clone(), Arc::new(FunctionState::default())).into());
df_context.register_udf(create_udf(
udf,
ctx.clone(),
Arc::new(FunctionState::default()),
));
}
Ok(df_context)

View File

@@ -159,14 +159,11 @@ mod tests {
}
fn matches_func() -> Arc<ScalarUDF> {
Arc::new(
create_udf(
FUNCTION_REGISTRY.get_function("matches").unwrap(),
QueryContext::arc(),
Default::default(),
)
.into(),
)
Arc::new(create_udf(
FUNCTION_REGISTRY.get_function("matches").unwrap(),
QueryContext::arc(),
Default::default(),
))
}
#[test]

View File

@@ -27,7 +27,6 @@ use common_catalog::consts::is_readonly_schema;
use common_error::ext::BoxedError;
use common_function::function::FunctionRef;
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_query::prelude::ScalarUdf;
use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
@@ -455,11 +454,6 @@ impl QueryEngine for DatafusionQueryEngine {
self.state.register_aggregate_function(func);
}
/// Register a [`ScalarUdf`].
fn register_udf(&self, udf: ScalarUdf) {
self.state.register_udf(udf);
}
/// Register an UDF function.
/// Will override if the function with same name is already registered.
fn register_function(&self, func: FunctionRef) {

View File

@@ -155,14 +155,11 @@ impl ContextProvider for DfContextProviderAdapter {
self.engine_state.udf_function(name).map_or_else(
|| self.session_state.scalar_functions().get(name).cloned(),
|func| {
Some(Arc::new(
create_udf(
func,
self.query_ctx.clone(),
self.engine_state.function_state(),
)
.into(),
))
Some(Arc::new(create_udf(
func,
self.query_ctx.clone(),
self.engine_state.function_state(),
)))
},
)
}

View File

@@ -28,7 +28,6 @@ use common_function::handlers::{
FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef,
};
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_query::prelude::ScalarUdf;
use common_query::Output;
use datafusion_expr::LogicalPlan;
use datatypes::schema::Schema;
@@ -75,9 +74,6 @@ pub trait QueryEngine: Send + Sync {
/// Execute the given [`LogicalPlan`].
async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output>;
/// Register a [`ScalarUdf`].
fn register_udf(&self, udf: ScalarUdf);
/// Register an aggregate function.
///
/// # Panics

View File

@@ -27,7 +27,7 @@ use datafusion::execution::context::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
use datafusion::logical_expr::LogicalPlan;
use datafusion_expr::{ScalarUDF, UserDefinedLogicalNode};
use datafusion_expr::UserDefinedLogicalNode;
use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
use prost::Message;
use session::context::QueryContextRef;
@@ -120,9 +120,11 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder {
// e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()`
// before we build the session state, the UDF will be lost.
for func in FUNCTION_REGISTRY.functions() {
let udf: Arc<ScalarUDF> = Arc::new(
create_udf(func.clone(), self.query_ctx.clone(), Default::default()).into(),
);
let udf = Arc::new(create_udf(
func.clone(),
self.query_ctx.clone(),
Default::default(),
));
session_state
.register_udf(udf)
.context(RegisterUdfSnafu { name: func.name() })?;

View File

@@ -25,7 +25,6 @@ use common_function::handlers::{
};
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
use common_function::state::FunctionState;
use common_query::prelude::ScalarUdf;
use common_telemetry::warn;
use datafusion::dataframe::DataFrame;
use datafusion::error::Result as DfResult;
@@ -242,11 +241,6 @@ impl QueryEngineState {
.collect()
}
/// Register a [`ScalarUdf`].
pub fn register_udf(&self, udf: ScalarUdf) {
self.df_context.register_udf(udf.into());
}
/// Register an aggregate function.
///
/// # Panics

View File

@@ -32,7 +32,6 @@ mod scipy_stats_norm_pdf;
mod time_range_filter_test;
mod function;
mod pow;
mod vec_product_test;
mod vec_sum_test;

View File

@@ -1,49 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_query::error::Result;
use datatypes::prelude::{ScalarVector, Vector};
use datatypes::vectors::{UInt32Vector, VectorRef};
pub fn pow(args: &[VectorRef]) -> Result<VectorRef> {
assert_eq!(args.len(), 2);
let base = &args[0]
.as_any()
.downcast_ref::<UInt32Vector>()
.expect("cast failed");
let exponent = &args[1]
.as_any()
.downcast_ref::<UInt32Vector>()
.expect("cast failed");
assert_eq!(exponent.len(), base.len());
let iter = base
.iter_data()
.zip(exponent.iter_data())
.map(|(base, exponent)| {
match (base, exponent) {
// in arrow, any value can be null.
// Here we decide to make our UDF to return null when either base or exponent is null.
(Some(base), Some(exponent)) => Some(base.pow(exponent)),
_ => None,
}
});
let v = UInt32Vector::from_owned_iterator(iter);
Ok(Arc::new(v) as _)
}

View File

@@ -19,7 +19,6 @@ use catalog::RegisterTableRequest;
use common_base::Plugins;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
use common_error::ext::BoxedError;
use common_query::prelude::{create_udf, make_scalar_function, Volatility};
use common_query::OutputData;
use common_recordbatch::{util, RecordBatch};
use datafusion::datasource::DefaultTableSource;
@@ -37,8 +36,6 @@ use crate::error::{QueryExecutionSnafu, Result};
use crate::parser::QueryLanguageParser;
use crate::query_engine::options::QueryOptions;
use crate::query_engine::QueryEngineFactory;
use crate::tests::exec_selection;
use crate::tests::pow::pow;
#[tokio::test]
async fn test_datafusion_query_engine() -> Result<()> {
@@ -150,46 +147,3 @@ async fn test_query_validate() -> Result<()> {
.is_err());
Ok(())
}
#[tokio::test]
async fn test_udf() -> Result<()> {
common_telemetry::init_default_ut_logging();
let catalog_list = catalog_manager()?;
let factory = QueryEngineFactory::new(catalog_list, None, None, None, None, false);
let engine = factory.query_engine();
let pow = make_scalar_function(pow);
let udf = create_udf(
// datafusion already supports pow, so we use a different name.
"my_pow",
vec![
ConcreteDataType::uint32_datatype(),
ConcreteDataType::uint32_datatype(),
],
Arc::new(ConcreteDataType::uint32_datatype()),
Volatility::Immutable,
pow,
);
engine.register_udf(udf);
let sql = "select my_pow(number, number) as p from numbers limit 10";
let numbers = exec_selection(engine, sql).await;
assert_eq!(1, numbers.len());
assert_eq!(numbers[0].num_columns(), 1);
assert_eq!(1, numbers[0].schema.num_columns());
assert_eq!("p", numbers[0].schema.column_schemas()[0].name);
let batch = &numbers[0];
assert_eq!(1, batch.num_columns());
assert_eq!(batch.column(0).len(), 10);
let expected: Vec<u32> = vec![1, 1, 4, 27, 256, 3125, 46656, 823543, 16777216, 387420489];
assert_eq!(
*batch.column(0),
Arc::new(UInt32Vector::from_slice(expected)) as VectorRef
);
Ok(())
}