From dee76f0a73c2252675b7837bedd695b47bc9ac58 Mon Sep 17 00:00:00 2001 From: LFC <990479+MichaelScofield@users.noreply.github.com> Date: Mon, 3 Mar 2025 13:52:44 +0800 Subject: [PATCH] refactor: simplify udf (#5617) * refactor: simplify udf * fix tests --- Cargo.lock | 2 + src/common/function/Cargo.toml | 2 + src/common/function/src/function.rs | 2 +- .../function/src/scalars/date/date_add.rs | 6 +- .../function/src/scalars/date/date_format.rs | 8 +- .../function/src/scalars/date/date_sub.rs | 6 +- .../src/scalars/expression/is_null.rs | 4 +- .../function/src/scalars/geo/geohash.rs | 4 +- src/common/function/src/scalars/geo/h3.rs | 42 ++--- .../function/src/scalars/geo/measure.rs | 6 +- .../function/src/scalars/geo/relation.rs | 6 +- src/common/function/src/scalars/geo/s2.rs | 8 +- src/common/function/src/scalars/geo/wkt.rs | 2 +- src/common/function/src/scalars/hll_count.rs | 8 +- .../function/src/scalars/json/json_get.rs | 12 +- .../function/src/scalars/json/json_is.rs | 4 +- .../src/scalars/json/json_path_exists.rs | 10 +- .../src/scalars/json/json_path_match.rs | 4 +- .../src/scalars/json/json_to_string.rs | 6 +- .../function/src/scalars/json/parse_json.rs | 4 +- src/common/function/src/scalars/matches.rs | 16 +- src/common/function/src/scalars/math.rs | 2 +- src/common/function/src/scalars/math/clamp.rs | 26 ++- .../function/src/scalars/math/modulo.rs | 14 +- src/common/function/src/scalars/math/pow.rs | 4 +- src/common/function/src/scalars/math/rate.rs | 4 +- src/common/function/src/scalars/test.rs | 2 +- .../src/scalars/timestamp/greatest.rs | 16 +- .../src/scalars/timestamp/to_unixtime.rs | 16 +- .../function/src/scalars/uddsketch_calc.rs | 8 +- src/common/function/src/scalars/udf.rs | 164 +++++++++++------- .../scalars/vector/convert/parse_vector.rs | 10 +- .../vector/convert/vector_to_string.rs | 4 +- .../function/src/scalars/vector/distance.rs | 26 +-- .../src/scalars/vector/elem_product.rs | 4 +- .../function/src/scalars/vector/elem_sum.rs | 4 +- .../function/src/scalars/vector/scalar_add.rs | 4 +- .../function/src/scalars/vector/scalar_mul.rs | 4 +- .../function/src/scalars/vector/vector_add.rs | 6 +- .../function/src/scalars/vector/vector_dim.rs | 6 +- .../function/src/scalars/vector/vector_div.rs | 8 +- .../function/src/scalars/vector/vector_mul.rs | 6 +- .../src/scalars/vector/vector_norm.rs | 4 +- .../function/src/scalars/vector/vector_sub.rs | 6 +- src/common/function/src/system/build.rs | 4 +- src/common/function/src/system/database.rs | 8 +- .../src/system/pg_catalog/pg_get_userbyid.rs | 2 +- .../src/system/pg_catalog/table_is_visible.rs | 2 +- .../function/src/system/pg_catalog/version.rs | 2 +- src/common/function/src/system/timezone.rs | 4 +- src/common/function/src/system/version.rs | 2 +- src/common/query/src/error.rs | 12 +- src/common/query/src/function.rs | 115 +----------- src/common/query/src/logical_plan.rs | 116 +------------ src/common/query/src/logical_plan/udf.rs | 134 -------------- src/common/query/src/prelude.rs | 2 +- src/datanode/src/tests.rs | 3 - src/flow/src/transform.rs | 3 +- src/frontend/src/instance/jaeger.rs | 7 +- .../index/fulltext_index/applier/builder.rs | 13 +- src/query/src/datafusion.rs | 6 - src/query/src/datafusion/planner.rs | 13 +- src/query/src/query_engine.rs | 4 - .../src/query_engine/default_serializer.rs | 10 +- src/query/src/query_engine/state.rs | 6 - src/query/src/tests.rs | 1 - src/query/src/tests/pow.rs | 49 ------ src/query/src/tests/query_engine_test.rs | 46 ----- 68 files changed, 323 insertions(+), 751 deletions(-) delete mode 100644 src/common/query/src/logical_plan/udf.rs delete mode 100644 src/query/src/tests/pow.rs diff --git a/Cargo.lock b/Cargo.lock index 9e940d4de7..dc70cbc8fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2026,6 +2026,8 @@ dependencies = [ "common-time", "common-version", "datafusion", + "datafusion-common", + "datafusion-expr", "datatypes", "derive_more", "geo", diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index d2aa4a86c3..f736c7f377 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -28,6 +28,8 @@ common-telemetry.workspace = true common-time.workspace = true common-version.workspace = true datafusion.workspace = true +datafusion-common.workspace = true +datafusion-expr.workspace = true datatypes.workspace = true derive_more = { version = "1", default-features = false, features = ["display"] } geo = { version = "0.29", optional = true } diff --git a/src/common/function/src/function.rs b/src/common/function/src/function.rs index d7e2d310e2..999361dc19 100644 --- a/src/common/function/src/function.rs +++ b/src/common/function/src/function.rs @@ -63,7 +63,7 @@ pub trait Function: fmt::Display + Sync + Send { fn signature(&self) -> Signature; /// Evaluate the function, e.g. run/execute the function. - fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result; + fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result; } pub type FunctionRef = Arc; diff --git a/src/common/function/src/scalars/date/date_add.rs b/src/common/function/src/scalars/date/date_add.rs index b2e5e4abe9..76cd3130c2 100644 --- a/src/common/function/src/scalars/date/date_add.rs +++ b/src/common/function/src/scalars/date/date_add.rs @@ -58,7 +58,7 @@ impl Function for DateAddFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -146,7 +146,7 @@ mod tests { let time_vector = TimestampSecondVector::from(times.clone()); let interval_vector = IntervalDayTimeVector::from_vec(intervals); let args: Vec = vec![Arc::new(time_vector), Arc::new(interval_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { @@ -178,7 +178,7 @@ mod tests { let date_vector = DateVector::from(dates.clone()); let interval_vector = IntervalYearMonthVector::from_vec(intervals); let args: Vec = vec![Arc::new(date_vector), Arc::new(interval_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in dates.iter().enumerate() { diff --git a/src/common/function/src/scalars/date/date_format.rs b/src/common/function/src/scalars/date/date_format.rs index fc82dbe06e..ba1a31b1f6 100644 --- a/src/common/function/src/scalars/date/date_format.rs +++ b/src/common/function/src/scalars/date/date_format.rs @@ -53,7 +53,7 @@ impl Function for DateFormatFunction { ) } - fn eval(&self, func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -202,7 +202,7 @@ mod tests { let time_vector = TimestampSecondVector::from(times.clone()); let interval_vector = StringVector::from_vec(formats); let args: Vec = vec![Arc::new(time_vector), Arc::new(interval_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { @@ -243,7 +243,7 @@ mod tests { let date_vector = DateVector::from(dates.clone()); let interval_vector = StringVector::from_vec(formats); let args: Vec = vec![Arc::new(date_vector), Arc::new(interval_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in dates.iter().enumerate() { @@ -284,7 +284,7 @@ mod tests { let date_vector = DateTimeVector::from(dates.clone()); let interval_vector = StringVector::from_vec(formats); let args: Vec = vec![Arc::new(date_vector), Arc::new(interval_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in dates.iter().enumerate() { diff --git a/src/common/function/src/scalars/date/date_sub.rs b/src/common/function/src/scalars/date/date_sub.rs index 1765e5b24a..da1155eebf 100644 --- a/src/common/function/src/scalars/date/date_sub.rs +++ b/src/common/function/src/scalars/date/date_sub.rs @@ -58,7 +58,7 @@ impl Function for DateSubFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -151,7 +151,7 @@ mod tests { let time_vector = TimestampSecondVector::from(times.clone()); let interval_vector = IntervalDayTimeVector::from_vec(intervals); let args: Vec = vec![Arc::new(time_vector), Arc::new(interval_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { @@ -189,7 +189,7 @@ mod tests { let date_vector = DateVector::from(dates.clone()); let interval_vector = IntervalYearMonthVector::from_vec(intervals); let args: Vec = vec![Arc::new(date_vector), Arc::new(interval_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in dates.iter().enumerate() { diff --git a/src/common/function/src/scalars/expression/is_null.rs b/src/common/function/src/scalars/expression/is_null.rs index ecae98cf87..8b396943bb 100644 --- a/src/common/function/src/scalars/expression/is_null.rs +++ b/src/common/function/src/scalars/expression/is_null.rs @@ -55,7 +55,7 @@ impl Function for IsNullFunction { fn eval( &self, - _func_ctx: FunctionContext, + _func_ctx: &FunctionContext, columns: &[VectorRef], ) -> common_query::error::Result { ensure!( @@ -102,7 +102,7 @@ mod tests { let values = vec![None, Some(3.0), None]; let args: Vec = vec![Arc::new(Float32Vector::from(values))]; - let vector = is_null.eval(FunctionContext::default(), &args).unwrap(); + let vector = is_null.eval(&FunctionContext::default(), &args).unwrap(); let expect: VectorRef = Arc::new(BooleanVector::from_vec(vec![true, false, true])); assert_eq!(expect, vector); } diff --git a/src/common/function/src/scalars/geo/geohash.rs b/src/common/function/src/scalars/geo/geohash.rs index d35a6a06ff..6fae2b79c9 100644 --- a/src/common/function/src/scalars/geo/geohash.rs +++ b/src/common/function/src/scalars/geo/geohash.rs @@ -118,7 +118,7 @@ impl Function for GeohashFunction { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 3, InvalidFuncArgsSnafu { @@ -218,7 +218,7 @@ impl Function for GeohashNeighboursFunction { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 3, InvalidFuncArgsSnafu { diff --git a/src/common/function/src/scalars/geo/h3.rs b/src/common/function/src/scalars/geo/h3.rs index e86c903dc2..b04f386967 100644 --- a/src/common/function/src/scalars/geo/h3.rs +++ b/src/common/function/src/scalars/geo/h3.rs @@ -119,7 +119,7 @@ impl Function for H3LatLngToCell { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 3); let lat_vec = &columns[0]; @@ -191,7 +191,7 @@ impl Function for H3LatLngToCellString { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 3); let lat_vec = &columns[0]; @@ -247,7 +247,7 @@ impl Function for H3CellToString { signature_of_cell() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let cell_vec = &columns[0]; @@ -285,7 +285,7 @@ impl Function for H3StringToCell { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let string_vec = &columns[0]; @@ -337,7 +337,7 @@ impl Function for H3CellCenterLatLng { signature_of_cell() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let cell_vec = &columns[0]; @@ -382,7 +382,7 @@ impl Function for H3CellResolution { signature_of_cell() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let cell_vec = &columns[0]; @@ -418,7 +418,7 @@ impl Function for H3CellBase { signature_of_cell() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let cell_vec = &columns[0]; @@ -454,7 +454,7 @@ impl Function for H3CellIsPentagon { signature_of_cell() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let cell_vec = &columns[0]; @@ -490,7 +490,7 @@ impl Function for H3CellCenterChild { signature_of_cell_and_resolution() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; @@ -530,7 +530,7 @@ impl Function for H3CellParent { signature_of_cell_and_resolution() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; @@ -570,7 +570,7 @@ impl Function for H3CellToChildren { signature_of_cell_and_resolution() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; @@ -619,7 +619,7 @@ impl Function for H3CellToChildrenSize { signature_of_cell_and_resolution() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; @@ -656,7 +656,7 @@ impl Function for H3CellToChildPos { signature_of_cell_and_resolution() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; @@ -706,7 +706,7 @@ impl Function for H3ChildPosToCell { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 3); let pos_vec = &columns[0]; @@ -747,7 +747,7 @@ impl Function for H3GridDisk { signature_of_cell_and_distance() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; @@ -800,7 +800,7 @@ impl Function for H3GridDiskDistances { signature_of_cell_and_distance() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; @@ -850,7 +850,7 @@ impl Function for H3GridDistance { signature_of_double_cells() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_this_vec = &columns[0]; @@ -906,7 +906,7 @@ impl Function for H3GridPathCells { signature_of_double_cells() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_this_vec = &columns[0]; @@ -988,7 +988,7 @@ impl Function for H3CellContains { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cells_vec = &columns[0]; @@ -1042,7 +1042,7 @@ impl Function for H3CellDistanceSphereKm { signature_of_double_cells() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_this_vec = &columns[0]; @@ -1097,7 +1097,7 @@ impl Function for H3CellDistanceEuclideanDegree { signature_of_double_cells() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_this_vec = &columns[0]; diff --git a/src/common/function/src/scalars/geo/measure.rs b/src/common/function/src/scalars/geo/measure.rs index a182259903..c95959a4a7 100644 --- a/src/common/function/src/scalars/geo/measure.rs +++ b/src/common/function/src/scalars/geo/measure.rs @@ -54,7 +54,7 @@ impl Function for STDistance { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let wkt_this_vec = &columns[0]; @@ -108,7 +108,7 @@ impl Function for STDistanceSphere { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let wkt_this_vec = &columns[0]; @@ -169,7 +169,7 @@ impl Function for STArea { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let wkt_vec = &columns[0]; diff --git a/src/common/function/src/scalars/geo/relation.rs b/src/common/function/src/scalars/geo/relation.rs index 570a7c7f56..45e99888ac 100644 --- a/src/common/function/src/scalars/geo/relation.rs +++ b/src/common/function/src/scalars/geo/relation.rs @@ -51,7 +51,7 @@ impl Function for STContains { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let wkt_this_vec = &columns[0]; @@ -105,7 +105,7 @@ impl Function for STWithin { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let wkt_this_vec = &columns[0]; @@ -159,7 +159,7 @@ impl Function for STIntersects { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let wkt_this_vec = &columns[0]; diff --git a/src/common/function/src/scalars/geo/s2.rs b/src/common/function/src/scalars/geo/s2.rs index 6e40dc300f..803a276968 100644 --- a/src/common/function/src/scalars/geo/s2.rs +++ b/src/common/function/src/scalars/geo/s2.rs @@ -84,7 +84,7 @@ impl Function for S2LatLngToCell { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let lat_vec = &columns[0]; @@ -138,7 +138,7 @@ impl Function for S2CellLevel { signature_of_cell() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let cell_vec = &columns[0]; @@ -174,7 +174,7 @@ impl Function for S2CellToToken { signature_of_cell() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 1); let cell_vec = &columns[0]; @@ -210,7 +210,7 @@ impl Function for S2CellParent { signature_of_cell_and_level() } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let cell_vec = &columns[0]; diff --git a/src/common/function/src/scalars/geo/wkt.rs b/src/common/function/src/scalars/geo/wkt.rs index 3602eb5d36..2a43ee31c0 100644 --- a/src/common/function/src/scalars/geo/wkt.rs +++ b/src/common/function/src/scalars/geo/wkt.rs @@ -63,7 +63,7 @@ impl Function for LatLngToPointWkt { Signature::one_of(signatures, Volatility::Stable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure_columns_n!(columns, 2); let lat_vec = &columns[0]; diff --git a/src/common/function/src/scalars/hll_count.rs b/src/common/function/src/scalars/hll_count.rs index e2a00d9d49..6cde0c7064 100644 --- a/src/common/function/src/scalars/hll_count.rs +++ b/src/common/function/src/scalars/hll_count.rs @@ -71,7 +71,7 @@ impl Function for HllCalcFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { if columns.len() != 1 { return InvalidFuncArgsSnafu { err_msg: format!("hll_count expects 1 argument, got {}", columns.len()), @@ -142,7 +142,7 @@ mod tests { let serialized_bytes = bincode::serialize(&hll).unwrap(); let args: Vec = vec![Arc::new(BinaryVector::from(vec![Some(serialized_bytes)]))]; - let result = function.eval(FunctionContext::default(), &args).unwrap(); + let result = function.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(result.len(), 1); // Test cardinality estimate @@ -159,7 +159,7 @@ mod tests { // Test with invalid number of arguments let args: Vec = vec![]; - let result = function.eval(FunctionContext::default(), &args); + let result = function.eval(&FunctionContext::default(), &args); assert!(result.is_err()); assert!(result .unwrap_err() @@ -168,7 +168,7 @@ mod tests { // Test with invalid binary data let args: Vec = vec![Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]))]; // Invalid binary data - let result = function.eval(FunctionContext::default(), &args).unwrap(); + let result = function.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(result.len(), 1); assert!(matches!(result.get(0), datatypes::value::Value::Null)); } diff --git a/src/common/function/src/scalars/json/json_get.rs b/src/common/function/src/scalars/json/json_get.rs index d31f7a0c6e..8dd35a54dd 100644 --- a/src/common/function/src/scalars/json/json_get.rs +++ b/src/common/function/src/scalars/json/json_get.rs @@ -72,7 +72,7 @@ macro_rules! json_get { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -175,7 +175,7 @@ impl Function for JsonGetString { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -282,7 +282,7 @@ mod tests { let path_vector = StringVector::from_vec(paths); let args: Vec = vec![Arc::new(json_vector), Arc::new(path_vector)]; let vector = json_get_int - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); assert_eq!(3, vector.len()); @@ -335,7 +335,7 @@ mod tests { let path_vector = StringVector::from_vec(paths); let args: Vec = vec![Arc::new(json_vector), Arc::new(path_vector)]; let vector = json_get_float - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); assert_eq!(3, vector.len()); @@ -388,7 +388,7 @@ mod tests { let path_vector = StringVector::from_vec(paths); let args: Vec = vec![Arc::new(json_vector), Arc::new(path_vector)]; let vector = json_get_bool - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); assert_eq!(3, vector.len()); @@ -441,7 +441,7 @@ mod tests { let path_vector = StringVector::from_vec(paths); let args: Vec = vec![Arc::new(json_vector), Arc::new(path_vector)]; let vector = json_get_string - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); assert_eq!(3, vector.len()); diff --git a/src/common/function/src/scalars/json/json_is.rs b/src/common/function/src/scalars/json/json_is.rs index e0580ad9d4..8a712305d7 100644 --- a/src/common/function/src/scalars/json/json_is.rs +++ b/src/common/function/src/scalars/json/json_is.rs @@ -45,7 +45,7 @@ macro_rules! json_is { Signature::exact(vec![ConcreteDataType::json_datatype()], Volatility::Immutable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 1, InvalidFuncArgsSnafu { @@ -202,7 +202,7 @@ mod tests { let args: Vec = vec![Arc::new(json_vector)]; for (func, expected_result) in json_is_functions.iter().zip(expected_results.iter()) { - let vector = func.eval(FunctionContext::default(), &args).unwrap(); + let vector = func.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(vector.len(), json_strings.len()); for (i, expected) in expected_result.iter().enumerate() { diff --git a/src/common/function/src/scalars/json/json_path_exists.rs b/src/common/function/src/scalars/json/json_path_exists.rs index 69e37cdbe4..1db57d3871 100644 --- a/src/common/function/src/scalars/json/json_path_exists.rs +++ b/src/common/function/src/scalars/json/json_path_exists.rs @@ -64,7 +64,7 @@ impl Function for JsonPathExistsFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -204,7 +204,7 @@ mod tests { let path_vector = StringVector::from_vec(paths); let args: Vec = vec![Arc::new(json_vector), Arc::new(path_vector)]; let vector = json_path_exists - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); // Test for non-nulls. @@ -222,7 +222,7 @@ mod tests { let illegal_path = StringVector::from_vec(vec!["$..a"]); let args: Vec = vec![Arc::new(json), Arc::new(illegal_path)]; - let err = json_path_exists.eval(FunctionContext::default(), &args); + let err = json_path_exists.eval(&FunctionContext::default(), &args); assert!(err.is_err()); // Test for nulls. @@ -235,11 +235,11 @@ mod tests { let args: Vec = vec![Arc::new(null_json), Arc::new(path)]; let result1 = json_path_exists - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); let args: Vec = vec![Arc::new(json), Arc::new(null_path)]; let result2 = json_path_exists - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); assert_eq!(result1.len(), 1); diff --git a/src/common/function/src/scalars/json/json_path_match.rs b/src/common/function/src/scalars/json/json_path_match.rs index 8ea1bf082b..db4b3a2010 100644 --- a/src/common/function/src/scalars/json/json_path_match.rs +++ b/src/common/function/src/scalars/json/json_path_match.rs @@ -50,7 +50,7 @@ impl Function for JsonPathMatchFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -180,7 +180,7 @@ mod tests { let path_vector = StringVector::from(paths); let args: Vec = vec![Arc::new(json_vector), Arc::new(path_vector)]; let vector = json_path_match - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); assert_eq!(7, vector.len()); diff --git a/src/common/function/src/scalars/json/json_to_string.rs b/src/common/function/src/scalars/json/json_to_string.rs index 9873000d6e..61b1ac6e7a 100644 --- a/src/common/function/src/scalars/json/json_to_string.rs +++ b/src/common/function/src/scalars/json/json_to_string.rs @@ -47,7 +47,7 @@ impl Function for JsonToStringFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 1, InvalidFuncArgsSnafu { @@ -154,7 +154,7 @@ mod tests { let json_vector = BinaryVector::from_vec(jsonbs); let args: Vec = vec![Arc::new(json_vector)]; let vector = json_to_string - .eval(FunctionContext::default(), &args) + .eval(&FunctionContext::default(), &args) .unwrap(); assert_eq!(3, vector.len()); @@ -168,7 +168,7 @@ mod tests { let invalid_jsonb = vec![b"invalid json"]; let invalid_json_vector = BinaryVector::from_vec(invalid_jsonb); let args: Vec = vec![Arc::new(invalid_json_vector)]; - let vector = json_to_string.eval(FunctionContext::default(), &args); + let vector = json_to_string.eval(&FunctionContext::default(), &args); assert!(vector.is_err()); } } diff --git a/src/common/function/src/scalars/json/parse_json.rs b/src/common/function/src/scalars/json/parse_json.rs index 64300838d8..0c19b02522 100644 --- a/src/common/function/src/scalars/json/parse_json.rs +++ b/src/common/function/src/scalars/json/parse_json.rs @@ -47,7 +47,7 @@ impl Function for ParseJsonFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 1, InvalidFuncArgsSnafu { @@ -152,7 +152,7 @@ mod tests { let json_string_vector = StringVector::from_vec(json_strings.to_vec()); let args: Vec = vec![Arc::new(json_string_vector)]; - let vector = parse_json.eval(FunctionContext::default(), &args).unwrap(); + let vector = parse_json.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(3, vector.len()); for (i, gt) in jsonbs.iter().enumerate() { diff --git a/src/common/function/src/scalars/matches.rs b/src/common/function/src/scalars/matches.rs index c41b394636..edeffbb2f9 100644 --- a/src/common/function/src/scalars/matches.rs +++ b/src/common/function/src/scalars/matches.rs @@ -72,7 +72,7 @@ impl Function for MatchesFunction { } // TODO: read case-sensitive config - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -82,6 +82,12 @@ impl Function for MatchesFunction { ), } ); + + let data_column = &columns[0]; + if data_column.is_empty() { + return Ok(Arc::new(BooleanVector::from(Vec::::with_capacity(0)))); + } + let pattern_vector = &columns[1] .cast(&ConcreteDataType::string_datatype()) .context(InvalidInputTypeSnafu { @@ -89,12 +95,12 @@ impl Function for MatchesFunction { })?; // Safety: both length and type are checked before let pattern = pattern_vector.get(0).as_string().unwrap(); - self.eval(columns[0].clone(), pattern) + self.eval(data_column, pattern) } } impl MatchesFunction { - fn eval(&self, data: VectorRef, pattern: String) -> Result { + fn eval(&self, data: &VectorRef, pattern: String) -> Result { let col_name = "data"; let parser_context = ParserContext::default(); let raw_ast = parser_context.parse_pattern(&pattern)?; @@ -1309,7 +1315,7 @@ mod test { "The quick brown fox jumps over dog", "The quick brown fox jumps over the dog", ]; - let input_vector = Arc::new(StringVector::from(input_data)); + let input_vector: VectorRef = Arc::new(StringVector::from(input_data)); let cases = [ // basic cases ("quick", vec![true, false, true, true, true, true, true]), @@ -1400,7 +1406,7 @@ mod test { let f = MatchesFunction; for (pattern, expected) in cases { - let actual: VectorRef = f.eval(input_vector.clone(), pattern.to_string()).unwrap(); + let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap(); let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _; assert_eq!(expected, actual, "{pattern}"); } diff --git a/src/common/function/src/scalars/math.rs b/src/common/function/src/scalars/math.rs index 6635e70b17..152ba999f3 100644 --- a/src/common/function/src/scalars/math.rs +++ b/src/common/function/src/scalars/math.rs @@ -80,7 +80,7 @@ impl Function for RangeFunction { Signature::variadic_any(Volatility::Immutable) } - fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { Err(DataFusionError::Internal( "range_fn just a empty function used in range select, It should not be eval!".into(), )) diff --git a/src/common/function/src/scalars/math/clamp.rs b/src/common/function/src/scalars/math/clamp.rs index dc73aed158..6c19da8212 100644 --- a/src/common/function/src/scalars/math/clamp.rs +++ b/src/common/function/src/scalars/math/clamp.rs @@ -27,7 +27,7 @@ use datatypes::vectors::PrimitiveVector; use datatypes::with_match_primitive_type_id; use snafu::{ensure, OptionExt}; -use crate::function::Function; +use crate::function::{Function, FunctionContext}; #[derive(Clone, Debug, Default)] pub struct ClampFunction; @@ -49,11 +49,7 @@ impl Function for ClampFunction { Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable) } - fn eval( - &self, - _func_ctx: crate::function::FunctionContext, - columns: &[VectorRef], - ) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 3, InvalidFuncArgsSnafu { @@ -209,7 +205,7 @@ mod test { Arc::new(Int64Vector::from_vec(vec![max])) as _, ]; let result = func - .eval(FunctionContext::default(), args.as_slice()) + .eval(&FunctionContext::default(), args.as_slice()) .unwrap(); let expected: VectorRef = Arc::new(Int64Vector::from(expected)); assert_eq!(expected, result); @@ -253,7 +249,7 @@ mod test { Arc::new(UInt64Vector::from_vec(vec![max])) as _, ]; let result = func - .eval(FunctionContext::default(), args.as_slice()) + .eval(&FunctionContext::default(), args.as_slice()) .unwrap(); let expected: VectorRef = Arc::new(UInt64Vector::from(expected)); assert_eq!(expected, result); @@ -297,7 +293,7 @@ mod test { Arc::new(Float64Vector::from_vec(vec![max])) as _, ]; let result = func - .eval(FunctionContext::default(), args.as_slice()) + .eval(&FunctionContext::default(), args.as_slice()) .unwrap(); let expected: VectorRef = Arc::new(Float64Vector::from(expected)); assert_eq!(expected, result); @@ -317,7 +313,7 @@ mod test { Arc::new(Int64Vector::from_vec(vec![max])) as _, ]; let result = func - .eval(FunctionContext::default(), args.as_slice()) + .eval(&FunctionContext::default(), args.as_slice()) .unwrap(); let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)])); assert_eq!(expected, result); @@ -335,7 +331,7 @@ mod test { Arc::new(Float64Vector::from_vec(vec![min])) as _, Arc::new(Float64Vector::from_vec(vec![max])) as _, ]; - let result = func.eval(FunctionContext::default(), args.as_slice()); + let result = func.eval(&FunctionContext::default(), args.as_slice()); assert!(result.is_err()); } @@ -351,7 +347,7 @@ mod test { Arc::new(Int64Vector::from_vec(vec![min])) as _, Arc::new(UInt64Vector::from_vec(vec![max])) as _, ]; - let result = func.eval(FunctionContext::default(), args.as_slice()); + let result = func.eval(&FunctionContext::default(), args.as_slice()); assert!(result.is_err()); } @@ -367,7 +363,7 @@ mod test { Arc::new(Float64Vector::from_vec(vec![min, min])) as _, Arc::new(Float64Vector::from_vec(vec![max])) as _, ]; - let result = func.eval(FunctionContext::default(), args.as_slice()); + let result = func.eval(&FunctionContext::default(), args.as_slice()); assert!(result.is_err()); } @@ -381,7 +377,7 @@ mod test { Arc::new(Float64Vector::from(input)) as _, Arc::new(Float64Vector::from_vec(vec![min])) as _, ]; - let result = func.eval(FunctionContext::default(), args.as_slice()); + let result = func.eval(&FunctionContext::default(), args.as_slice()); assert!(result.is_err()); } @@ -395,7 +391,7 @@ mod test { Arc::new(StringVector::from_vec(vec!["bar"])) as _, Arc::new(StringVector::from_vec(vec!["baz"])) as _, ]; - let result = func.eval(FunctionContext::default(), args.as_slice()); + let result = func.eval(&FunctionContext::default(), args.as_slice()); assert!(result.is_err()); } } diff --git a/src/common/function/src/scalars/math/modulo.rs b/src/common/function/src/scalars/math/modulo.rs index d9ea174488..b9d19e9818 100644 --- a/src/common/function/src/scalars/math/modulo.rs +++ b/src/common/function/src/scalars/math/modulo.rs @@ -58,7 +58,7 @@ impl Function for ModuloFunction { Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -126,7 +126,7 @@ mod tests { Arc::new(Int32Vector::from_vec(nums.clone())), Arc::new(Int32Vector::from_vec(divs.clone())), ]; - let result = function.eval(FunctionContext::default(), &args).unwrap(); + let result = function.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(result.len(), 4); for i in 0..4 { let p: i64 = (nums[i] % divs[i]) as i64; @@ -158,7 +158,7 @@ mod tests { Arc::new(UInt32Vector::from_vec(nums.clone())), Arc::new(UInt32Vector::from_vec(divs.clone())), ]; - let result = function.eval(FunctionContext::default(), &args).unwrap(); + let result = function.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(result.len(), 4); for i in 0..4 { let p: u64 = (nums[i] % divs[i]) as u64; @@ -190,7 +190,7 @@ mod tests { Arc::new(Float64Vector::from_vec(nums.clone())), Arc::new(Float64Vector::from_vec(divs.clone())), ]; - let result = function.eval(FunctionContext::default(), &args).unwrap(); + let result = function.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(result.len(), 4); for i in 0..4 { let p: f64 = nums[i] % divs[i]; @@ -209,7 +209,7 @@ mod tests { Arc::new(Int32Vector::from_vec(nums.clone())), Arc::new(Int32Vector::from_vec(divs.clone())), ]; - let result = function.eval(FunctionContext::default(), &args); + let result = function.eval(&FunctionContext::default(), &args); assert!(result.is_err()); let err_msg = result.unwrap_err().output_msg(); assert_eq!( @@ -220,7 +220,7 @@ mod tests { let nums = vec![27]; let args: Vec = vec![Arc::new(Int32Vector::from_vec(nums.clone()))]; - let result = function.eval(FunctionContext::default(), &args); + let result = function.eval(&FunctionContext::default(), &args); assert!(result.is_err()); let err_msg = result.unwrap_err().output_msg(); assert!( @@ -233,7 +233,7 @@ mod tests { Arc::new(StringVector::from(nums.clone())), Arc::new(StringVector::from(divs.clone())), ]; - let result = function.eval(FunctionContext::default(), &args); + let result = function.eval(&FunctionContext::default(), &args); assert!(result.is_err()); let err_msg = result.unwrap_err().output_msg(); assert!(err_msg.contains("Invalid arithmetic operation")); diff --git a/src/common/function/src/scalars/math/pow.rs b/src/common/function/src/scalars/math/pow.rs index 5e6cc0f089..171c06a694 100644 --- a/src/common/function/src/scalars/math/pow.rs +++ b/src/common/function/src/scalars/math/pow.rs @@ -44,7 +44,7 @@ impl Function for PowFunction { Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| { let col = scalar_binary_op::<<$S as LogicalPrimitiveType>::Native, <$T as LogicalPrimitiveType>::Native, f64, _>(&columns[0], &columns[1], scalar_pow, &mut EvalContext::default())?; @@ -109,7 +109,7 @@ mod tests { Arc::new(Int8Vector::from_vec(bases.clone())), ]; - let vector = pow.eval(FunctionContext::default(), &args).unwrap(); + let vector = pow.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(3, vector.len()); for i in 0..3 { diff --git a/src/common/function/src/scalars/math/rate.rs b/src/common/function/src/scalars/math/rate.rs index 7afc07177d..e296fb9496 100644 --- a/src/common/function/src/scalars/math/rate.rs +++ b/src/common/function/src/scalars/math/rate.rs @@ -48,7 +48,7 @@ impl Function for RateFunction { Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { let val = &columns[0].to_arrow_array(); let val_0 = val.slice(0, val.len() - 1); let val_1 = val.slice(1, val.len() - 1); @@ -100,7 +100,7 @@ mod tests { Arc::new(Float32Vector::from_vec(values)), Arc::new(Int64Vector::from_vec(ts)), ]; - let vector = rate.eval(FunctionContext::default(), &args).unwrap(); + let vector = rate.eval(&FunctionContext::default(), &args).unwrap(); let expect: VectorRef = Arc::new(Float64Vector::from_vec(vec![2.0, 3.0])); assert_eq!(expect, vector); } diff --git a/src/common/function/src/scalars/test.rs b/src/common/function/src/scalars/test.rs index 573c2e715b..0fe05d3f27 100644 --- a/src/common/function/src/scalars/test.rs +++ b/src/common/function/src/scalars/test.rs @@ -45,7 +45,7 @@ impl Function for TestAndFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { let col = scalar_binary_op::( &columns[0], &columns[1], diff --git a/src/common/function/src/scalars/timestamp/greatest.rs b/src/common/function/src/scalars/timestamp/greatest.rs index 671a023d06..74cd6ad7d7 100644 --- a/src/common/function/src/scalars/timestamp/greatest.rs +++ b/src/common/function/src/scalars/timestamp/greatest.rs @@ -97,7 +97,7 @@ impl Function for GreatestFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -191,7 +191,9 @@ mod tests { ])) as _, ]; - let result = function.eval(FunctionContext::default(), &columns).unwrap(); + let result = function + .eval(&FunctionContext::default(), &columns) + .unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.len(), 2); assert_eq!( @@ -222,7 +224,9 @@ mod tests { Arc::new(DateVector::from_slice(vec![0, 1])) as _, ]; - let result = function.eval(FunctionContext::default(), &columns).unwrap(); + let result = function + .eval(&FunctionContext::default(), &columns) + .unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.len(), 2); assert_eq!( @@ -253,7 +257,9 @@ mod tests { Arc::new(DateTimeVector::from_slice(vec![0, 1])) as _, ]; - let result = function.eval(FunctionContext::default(), &columns).unwrap(); + let result = function + .eval(&FunctionContext::default(), &columns) + .unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.len(), 2); assert_eq!( @@ -282,7 +288,7 @@ mod tests { Arc::new([]::from_slice(vec![0, 1])) as _, ]; - let result = function.eval(FunctionContext::default(), &columns).unwrap(); + let result = function.eval(&FunctionContext::default(), &columns).unwrap(); let result = result.as_any().downcast_ref::<[]>().unwrap(); assert_eq!(result.len(), 2); assert_eq!( diff --git a/src/common/function/src/scalars/timestamp/to_unixtime.rs b/src/common/function/src/scalars/timestamp/to_unixtime.rs index cc297942d1..11b014839a 100644 --- a/src/common/function/src/scalars/timestamp/to_unixtime.rs +++ b/src/common/function/src/scalars/timestamp/to_unixtime.rs @@ -92,7 +92,7 @@ impl Function for ToUnixtimeFunction { ) } - fn eval(&self, func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 1, InvalidFuncArgsSnafu { @@ -108,7 +108,7 @@ impl Function for ToUnixtimeFunction { match columns[0].data_type() { ConcreteDataType::String(_) => Ok(Arc::new(Int64Vector::from( (0..vector.len()) - .map(|i| convert_to_seconds(&vector.get(i).to_string(), &func_ctx)) + .map(|i| convert_to_seconds(&vector.get(i).to_string(), ctx)) .collect::>(), ))), ConcreteDataType::Int64(_) | ConcreteDataType::Int32(_) => { @@ -187,7 +187,7 @@ mod tests { ]; let results = [Some(1677652502), None, Some(1656633600), None]; let args: Vec = vec![Arc::new(StringVector::from(times.clone()))]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { let v = vector.get(i); @@ -211,7 +211,7 @@ mod tests { let times = vec![Some(3_i64), None, Some(5_i64), None]; let results = [Some(3), None, Some(5), None]; let args: Vec = vec![Arc::new(Int64Vector::from(times.clone()))]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { let v = vector.get(i); @@ -236,7 +236,7 @@ mod tests { let results = [Some(10627200), None, Some(3628800), None]; let date_vector = DateVector::from(times.clone()); let args: Vec = vec![Arc::new(date_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { let v = vector.get(i); @@ -261,7 +261,7 @@ mod tests { let results = [Some(123), None, Some(42), None]; let date_vector = DateTimeVector::from(times.clone()); let args: Vec = vec![Arc::new(date_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { let v = vector.get(i); @@ -286,7 +286,7 @@ mod tests { let results = [Some(123), None, Some(42), None]; let ts_vector = TimestampSecondVector::from(times.clone()); let args: Vec = vec![Arc::new(ts_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { let v = vector.get(i); @@ -306,7 +306,7 @@ mod tests { let results = [Some(123), None, Some(42), None]; let ts_vector = TimestampMillisecondVector::from(times.clone()); let args: Vec = vec![Arc::new(ts_vector)]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(4, vector.len()); for (i, _t) in times.iter().enumerate() { let v = vector.get(i); diff --git a/src/common/function/src/scalars/uddsketch_calc.rs b/src/common/function/src/scalars/uddsketch_calc.rs index 4924458c47..5c0beb4fec 100644 --- a/src/common/function/src/scalars/uddsketch_calc.rs +++ b/src/common/function/src/scalars/uddsketch_calc.rs @@ -75,7 +75,7 @@ impl Function for UddSketchCalcFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { if columns.len() != 2 { return InvalidFuncArgsSnafu { err_msg: format!("uddsketch_calc expects 2 arguments, got {}", columns.len()), @@ -169,7 +169,7 @@ mod tests { Arc::new(BinaryVector::from(vec![Some(serialized.clone()); 3])), ]; - let result = function.eval(FunctionContext::default(), &args).unwrap(); + let result = function.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(result.len(), 3); // Test median (p50) @@ -192,7 +192,7 @@ mod tests { // Test with invalid number of arguments let args: Vec = vec![Arc::new(Float64Vector::from_vec(vec![0.95]))]; - let result = function.eval(FunctionContext::default(), &args); + let result = function.eval(&FunctionContext::default(), &args); assert!(result.is_err()); assert!(result .unwrap_err() @@ -204,7 +204,7 @@ mod tests { Arc::new(Float64Vector::from_vec(vec![0.95])), Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])])), // Invalid binary data ]; - let result = function.eval(FunctionContext::default(), &args).unwrap(); + let result = function.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(result.len(), 1); assert!(matches!(result.get(0), datatypes::value::Value::Null)); } diff --git a/src/common/function/src/scalars/udf.rs b/src/common/function/src/scalars/udf.rs index 593162e4ab..65c094bc6b 100644 --- a/src/common/function/src/scalars/udf.rs +++ b/src/common/function/src/scalars/udf.rs @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::Any; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; use common_query::error::FromScalarValueSnafu; -use common_query::prelude::{ - ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf, -}; -use datatypes::error::Error as DataTypeError; +use common_query::prelude::ColumnarValue; +use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_expr::ScalarUDF; +use datatypes::data_type::DataType; use datatypes::prelude::*; use datatypes::vectors::Helper; use session::context::QueryContextRef; @@ -27,58 +29,92 @@ use snafu::ResultExt; use crate::function::{FunctionContext, FunctionRef}; use crate::state::FunctionState; +struct ScalarUdf { + function: FunctionRef, + signature: datafusion_expr::Signature, + context: FunctionContext, +} + +impl Debug for ScalarUdf { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ScalarUdf") + .field("function", &self.function.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl ScalarUDFImpl for ScalarUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.function.name() + } + + fn signature(&self) -> &datafusion_expr::Signature { + &self.signature + } + + fn return_type( + &self, + arg_types: &[datatypes::arrow::datatypes::DataType], + ) -> datafusion_common::Result { + let arg_types = arg_types + .iter() + .map(ConcreteDataType::from_arrow_type) + .collect::>(); + let t = self.function.return_type(&arg_types)?; + Ok(t.as_arrow_type()) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + let columns = args + .args + .iter() + .map(|x| { + ColumnarValue::try_from(x).and_then(|y| match y { + ColumnarValue::Vector(z) => Ok(z), + ColumnarValue::Scalar(z) => Helper::try_from_scalar_value(z, args.number_rows) + .context(FromScalarValueSnafu), + }) + }) + .collect::>>()?; + let v = self + .function + .eval(&self.context, &columns) + .map(ColumnarValue::Vector)?; + Ok(v.into()) + } +} + /// Create a ScalarUdf from function, query context and state. pub fn create_udf( func: FunctionRef, query_ctx: QueryContextRef, state: Arc, -) -> ScalarUdf { - let func_cloned = func.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |input_types: &[ConcreteDataType]| { - Ok(Arc::new(func_cloned.return_type(input_types)?)) - }); - - let func_cloned = func.clone(); - - let fun: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| { - let func_ctx = FunctionContext { - query_ctx: query_ctx.clone(), - state: state.clone(), - }; - - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Vector(v) => Some(v.len()), - }); - - let rows = len.unwrap_or(1); - - let args: Result, DataTypeError> = args - .iter() - .map(|arg| match arg { - ColumnarValue::Scalar(v) => Helper::try_from_scalar_value(v.clone(), rows), - ColumnarValue::Vector(v) => Ok(v.clone()), - }) - .collect(); - - let result = func_cloned.eval(func_ctx, &args.context(FromScalarValueSnafu)?); - let udf_result = result.map(ColumnarValue::Vector)?; - Ok(udf_result) - }); - - ScalarUdf::new(func.name(), &func.signature(), &return_type, &fun) +) -> ScalarUDF { + let signature = func.signature().into(); + let udf = ScalarUdf { + function: func, + signature, + context: FunctionContext { query_ctx, state }, + }; + ScalarUDF::new_from_impl(udf) } #[cfg(test)] mod tests { use std::sync::Arc; - use common_query::prelude::{ColumnarValue, ScalarValue}; + use common_query::prelude::ScalarValue; + use datafusion::arrow::array::BooleanArray; use datatypes::data_type::ConcreteDataType; - use datatypes::prelude::{ScalarVector, Vector, VectorRef}; - use datatypes::value::Value; + use datatypes::prelude::VectorRef; use datatypes::vectors::{BooleanVector, ConstantVector}; use session::context::QueryContextBuilder; @@ -99,7 +135,7 @@ mod tests { Arc::new(BooleanVector::from(vec![true, false, true])), ]; - let vector = f.eval(FunctionContext::default(), &args).unwrap(); + let vector = f.eval(&FunctionContext::default(), &args).unwrap(); assert_eq!(3, vector.len()); for i in 0..3 { @@ -109,30 +145,36 @@ mod tests { // create a udf and test it again let udf = create_udf(f.clone(), query_ctx, Arc::new(FunctionState::default())); - assert_eq!("test_and", udf.name); - assert_eq!(f.signature(), udf.signature); + assert_eq!("test_and", udf.name()); + let expected_signature: datafusion_expr::Signature = f.signature().into(); + assert_eq!(udf.signature(), &expected_signature); assert_eq!( - Arc::new(ConcreteDataType::boolean_datatype()), - ((udf.return_type)(&[])).unwrap() + ConcreteDataType::boolean_datatype(), + udf.return_type(&[]) + .map(|x| ConcreteDataType::from_arrow_type(&x)) + .unwrap() ); let args = vec![ - ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))), - ColumnarValue::Vector(Arc::new(BooleanVector::from(vec![ + datafusion_expr::ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))), + datafusion_expr::ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ true, false, false, true, ]))), ]; - let vec = (udf.fun)(&args).unwrap(); - - match vec { - ColumnarValue::Vector(vec) => { - let vec = vec.as_any().downcast_ref::().unwrap(); - - assert_eq!(4, vec.len()); - for i in 0..4 { - assert_eq!(i == 0 || i == 3, vec.get_data(i).unwrap(), "Failed at {i}",) - } + let args = ScalarFunctionArgs { + args: &args, + number_rows: 4, + return_type: &ConcreteDataType::boolean_datatype().as_arrow_type(), + }; + match udf.invoke_with_args(args).unwrap() { + datafusion_expr::ColumnarValue::Array(x) => { + let x = x.as_any().downcast_ref::().unwrap(); + assert_eq!(x.len(), 4); + assert_eq!( + x.iter().flatten().collect::>(), + vec![true, false, false, true] + ); } _ => unreachable!(), } diff --git a/src/common/function/src/scalars/vector/convert/parse_vector.rs b/src/common/function/src/scalars/vector/convert/parse_vector.rs index ae92a10f44..796336bda1 100644 --- a/src/common/function/src/scalars/vector/convert/parse_vector.rs +++ b/src/common/function/src/scalars/vector/convert/parse_vector.rs @@ -45,7 +45,7 @@ impl Function for ParseVectorFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 1, InvalidFuncArgsSnafu { @@ -101,7 +101,7 @@ mod tests { None, ])); - let result = func.eval(FunctionContext::default(), &[input]).unwrap(); + let result = func.eval(&FunctionContext::default(), &[input]).unwrap(); let result = result.as_ref(); assert_eq!(result.len(), 3); @@ -136,7 +136,7 @@ mod tests { Some("[7.0,8.0,9.0".to_string()), ])); - let result = func.eval(FunctionContext::default(), &[input]); + let result = func.eval(&FunctionContext::default(), &[input]); assert!(result.is_err()); let input = Arc::new(StringVector::from(vec![ @@ -145,7 +145,7 @@ mod tests { Some("7.0,8.0,9.0]".to_string()), ])); - let result = func.eval(FunctionContext::default(), &[input]); + let result = func.eval(&FunctionContext::default(), &[input]); assert!(result.is_err()); let input = Arc::new(StringVector::from(vec![ @@ -154,7 +154,7 @@ mod tests { Some("[7.0,hello,9.0]".to_string()), ])); - let result = func.eval(FunctionContext::default(), &[input]); + let result = func.eval(&FunctionContext::default(), &[input]); assert!(result.is_err()); } } diff --git a/src/common/function/src/scalars/vector/convert/vector_to_string.rs b/src/common/function/src/scalars/vector/convert/vector_to_string.rs index 456b072910..73639c6a60 100644 --- a/src/common/function/src/scalars/vector/convert/vector_to_string.rs +++ b/src/common/function/src/scalars/vector/convert/vector_to_string.rs @@ -46,7 +46,7 @@ impl Function for VectorToStringFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 1, InvalidFuncArgsSnafu { @@ -129,7 +129,7 @@ mod tests { builder.push_null(); let vector = builder.to_vector(); - let result = func.eval(FunctionContext::default(), &[vector]).unwrap(); + let result = func.eval(&FunctionContext::default(), &[vector]).unwrap(); assert_eq!(result.len(), 3); assert_eq!(result.get(0), Value::String("[1,2,3]".to_string().into())); diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs index f17eec5b04..bc004d4eb0 100644 --- a/src/common/function/src/scalars/vector/distance.rs +++ b/src/common/function/src/scalars/vector/distance.rs @@ -60,7 +60,7 @@ macro_rules! define_distance_function { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -159,7 +159,7 @@ mod tests { ])) as VectorRef; let result = func - .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()]) .unwrap(); assert!(!result.get(0).is_null()); @@ -168,7 +168,7 @@ mod tests { assert!(result.get(3).is_null()); let result = func - .eval(FunctionContext::default(), &[vec2, vec1]) + .eval(&FunctionContext::default(), &[vec2, vec1]) .unwrap(); assert!(!result.get(0).is_null()); @@ -202,7 +202,7 @@ mod tests { ])) as VectorRef; let result = func - .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()]) .unwrap(); assert!(!result.get(0).is_null()); @@ -211,7 +211,7 @@ mod tests { assert!(result.get(3).is_null()); let result = func - .eval(FunctionContext::default(), &[vec2, vec1]) + .eval(&FunctionContext::default(), &[vec2, vec1]) .unwrap(); assert!(!result.get(0).is_null()); @@ -245,7 +245,7 @@ mod tests { ])) as VectorRef; let result = func - .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()]) .unwrap(); assert!(!result.get(0).is_null()); @@ -254,7 +254,7 @@ mod tests { assert!(result.get(3).is_null()); let result = func - .eval(FunctionContext::default(), &[vec2, vec1]) + .eval(&FunctionContext::default(), &[vec2, vec1]) .unwrap(); assert!(!result.get(0).is_null()); @@ -294,7 +294,7 @@ mod tests { let result = func .eval( - FunctionContext::default(), + &FunctionContext::default(), &[const_str.clone(), vec1.clone()], ) .unwrap(); @@ -306,7 +306,7 @@ mod tests { let result = func .eval( - FunctionContext::default(), + &FunctionContext::default(), &[vec1.clone(), const_str.clone()], ) .unwrap(); @@ -318,7 +318,7 @@ mod tests { let result = func .eval( - FunctionContext::default(), + &FunctionContext::default(), &[const_str.clone(), vec2.clone()], ) .unwrap(); @@ -330,7 +330,7 @@ mod tests { let result = func .eval( - FunctionContext::default(), + &FunctionContext::default(), &[vec2.clone(), const_str.clone()], ) .unwrap(); @@ -353,13 +353,13 @@ mod tests { for func in funcs { let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef; let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef; - let result = func.eval(FunctionContext::default(), &[vec1, vec2]); + let result = func.eval(&FunctionContext::default(), &[vec1, vec2]); assert!(result.is_err()); let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef; let vec2 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef; - let result = func.eval(FunctionContext::default(), &[vec1, vec2]); + let result = func.eval(&FunctionContext::default(), &[vec1, vec2]); assert!(result.is_err()); } } diff --git a/src/common/function/src/scalars/vector/elem_product.rs b/src/common/function/src/scalars/vector/elem_product.rs index 062000bb78..82c64958d7 100644 --- a/src/common/function/src/scalars/vector/elem_product.rs +++ b/src/common/function/src/scalars/vector/elem_product.rs @@ -68,7 +68,7 @@ impl Function for ElemProductFunction { fn eval( &self, - _func_ctx: FunctionContext, + _func_ctx: &FunctionContext, columns: &[VectorRef], ) -> common_query::error::Result { ensure!( @@ -131,7 +131,7 @@ mod tests { None, ])); - let result = func.eval(FunctionContext::default(), &[input0]).unwrap(); + let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); let result = result.as_ref(); assert_eq!(result.len(), 3); diff --git a/src/common/function/src/scalars/vector/elem_sum.rs b/src/common/function/src/scalars/vector/elem_sum.rs index 748614e05c..bc0459c6be 100644 --- a/src/common/function/src/scalars/vector/elem_sum.rs +++ b/src/common/function/src/scalars/vector/elem_sum.rs @@ -55,7 +55,7 @@ impl Function for ElemSumFunction { fn eval( &self, - _func_ctx: FunctionContext, + _func_ctx: &FunctionContext, columns: &[VectorRef], ) -> common_query::error::Result { ensure!( @@ -118,7 +118,7 @@ mod tests { None, ])); - let result = func.eval(FunctionContext::default(), &[input0]).unwrap(); + let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); let result = result.as_ref(); assert_eq!(result.len(), 3); diff --git a/src/common/function/src/scalars/vector/scalar_add.rs b/src/common/function/src/scalars/vector/scalar_add.rs index ef016eff4b..f6a070361b 100644 --- a/src/common/function/src/scalars/vector/scalar_add.rs +++ b/src/common/function/src/scalars/vector/scalar_add.rs @@ -73,7 +73,7 @@ impl Function for ScalarAddFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -154,7 +154,7 @@ mod tests { ])); let result = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap(); let result = result.as_ref(); diff --git a/src/common/function/src/scalars/vector/scalar_mul.rs b/src/common/function/src/scalars/vector/scalar_mul.rs index 3c7fe4c070..9f4480bb51 100644 --- a/src/common/function/src/scalars/vector/scalar_mul.rs +++ b/src/common/function/src/scalars/vector/scalar_mul.rs @@ -73,7 +73,7 @@ impl Function for ScalarMulFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -154,7 +154,7 @@ mod tests { ])); let result = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap(); let result = result.as_ref(); diff --git a/src/common/function/src/scalars/vector/vector_add.rs b/src/common/function/src/scalars/vector/vector_add.rs index f0fd9bbbc3..679ec38256 100644 --- a/src/common/function/src/scalars/vector/vector_add.rs +++ b/src/common/function/src/scalars/vector/vector_add.rs @@ -72,7 +72,7 @@ impl Function for VectorAddFunction { fn eval( &self, - _func_ctx: FunctionContext, + _func_ctx: &FunctionContext, columns: &[VectorRef], ) -> common_query::error::Result { ensure!( @@ -166,7 +166,7 @@ mod tests { ])); let result = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap(); let result = result.as_ref(); @@ -199,7 +199,7 @@ mod tests { Some("[3.0,2.0,2.0]".to_string()), ])); - let result = func.eval(FunctionContext::default(), &[input0, input1]); + let result = func.eval(&FunctionContext::default(), &[input0, input1]); match result { Err(Error::InvalidFuncArgs { err_msg, .. }) => { diff --git a/src/common/function/src/scalars/vector/vector_dim.rs b/src/common/function/src/scalars/vector/vector_dim.rs index 6a7c078100..59c38609ba 100644 --- a/src/common/function/src/scalars/vector/vector_dim.rs +++ b/src/common/function/src/scalars/vector/vector_dim.rs @@ -67,7 +67,7 @@ impl Function for VectorDimFunction { fn eval( &self, - _func_ctx: FunctionContext, + _func_ctx: &FunctionContext, columns: &[VectorRef], ) -> common_query::error::Result { ensure!( @@ -131,7 +131,7 @@ mod tests { Some("[5.0]".to_string()), ])); - let result = func.eval(FunctionContext::default(), &[input0]).unwrap(); + let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); let result = result.as_ref(); assert_eq!(result.len(), 4); @@ -157,7 +157,7 @@ mod tests { Some("[3.0,2.0,2.0]".to_string()), ])); - let result = func.eval(FunctionContext::default(), &[input0, input1]); + let result = func.eval(&FunctionContext::default(), &[input0, input1]); match result { Err(Error::InvalidFuncArgs { err_msg, .. }) => { diff --git a/src/common/function/src/scalars/vector/vector_div.rs b/src/common/function/src/scalars/vector/vector_div.rs index d7f57796a3..74e784aa41 100644 --- a/src/common/function/src/scalars/vector/vector_div.rs +++ b/src/common/function/src/scalars/vector/vector_div.rs @@ -68,7 +68,7 @@ impl Function for VectorDivFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -155,7 +155,7 @@ mod tests { let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))])); let err = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap_err(); match err { @@ -186,7 +186,7 @@ mod tests { ])); let result = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap(); let result = result.as_ref(); @@ -206,7 +206,7 @@ mod tests { let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())])); let result = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap(); let result = result.as_ref(); diff --git a/src/common/function/src/scalars/vector/vector_mul.rs b/src/common/function/src/scalars/vector/vector_mul.rs index 02e9833623..cbfe3e8452 100644 --- a/src/common/function/src/scalars/vector/vector_mul.rs +++ b/src/common/function/src/scalars/vector/vector_mul.rs @@ -68,7 +68,7 @@ impl Function for VectorMulFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { ensure!( columns.len() == 2, InvalidFuncArgsSnafu { @@ -155,7 +155,7 @@ mod tests { let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))])); let err = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap_err(); match err { @@ -186,7 +186,7 @@ mod tests { ])); let result = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap(); let result = result.as_ref(); diff --git a/src/common/function/src/scalars/vector/vector_norm.rs b/src/common/function/src/scalars/vector/vector_norm.rs index 62eeb395e0..b0979ddf7e 100644 --- a/src/common/function/src/scalars/vector/vector_norm.rs +++ b/src/common/function/src/scalars/vector/vector_norm.rs @@ -67,7 +67,7 @@ impl Function for VectorNormFunction { fn eval( &self, - _func_ctx: FunctionContext, + _func_ctx: &FunctionContext, columns: &[VectorRef], ) -> common_query::error::Result { ensure!( @@ -143,7 +143,7 @@ mod tests { None, ])); - let result = func.eval(FunctionContext::default(), &[input0]).unwrap(); + let result = func.eval(&FunctionContext::default(), &[input0]).unwrap(); let result = result.as_ref(); assert_eq!(result.len(), 5); diff --git a/src/common/function/src/scalars/vector/vector_sub.rs b/src/common/function/src/scalars/vector/vector_sub.rs index 7f97bb322e..8e034417c8 100644 --- a/src/common/function/src/scalars/vector/vector_sub.rs +++ b/src/common/function/src/scalars/vector/vector_sub.rs @@ -72,7 +72,7 @@ impl Function for VectorSubFunction { fn eval( &self, - _func_ctx: FunctionContext, + _func_ctx: &FunctionContext, columns: &[VectorRef], ) -> common_query::error::Result { ensure!( @@ -166,7 +166,7 @@ mod tests { ])); let result = func - .eval(FunctionContext::default(), &[input0, input1]) + .eval(&FunctionContext::default(), &[input0, input1]) .unwrap(); let result = result.as_ref(); @@ -199,7 +199,7 @@ mod tests { Some("[3.0,2.0,2.0]".to_string()), ])); - let result = func.eval(FunctionContext::default(), &[input0, input1]); + let result = func.eval(&FunctionContext::default(), &[input0, input1]); match result { Err(Error::InvalidFuncArgs { err_msg, .. }) => { diff --git a/src/common/function/src/system/build.rs b/src/common/function/src/system/build.rs index bd5b044a9c..1c17865325 100644 --- a/src/common/function/src/system/build.rs +++ b/src/common/function/src/system/build.rs @@ -45,7 +45,7 @@ impl Function for BuildFunction { Signature::nullary(Volatility::Immutable) } - fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let build_info = common_version::build_info().to_string(); let v = Arc::new(StringVector::from(vec![build_info])); Ok(v) @@ -67,7 +67,7 @@ mod tests { ); assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable)); let build_info = common_version::build_info().to_string(); - let vector = build.eval(FunctionContext::default(), &[]).unwrap(); + let vector = build.eval(&FunctionContext::default(), &[]).unwrap(); let expect: VectorRef = Arc::new(StringVector::from(vec![build_info])); assert_eq!(expect, vector); } diff --git a/src/common/function/src/system/database.rs b/src/common/function/src/system/database.rs index a9759de115..370bd2c8da 100644 --- a/src/common/function/src/system/database.rs +++ b/src/common/function/src/system/database.rs @@ -47,7 +47,7 @@ impl Function for DatabaseFunction { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let db = func_ctx.query_ctx.current_schema(); Ok(Arc::new(StringVector::from_slice(&[&db])) as _) @@ -67,7 +67,7 @@ impl Function for CurrentSchemaFunction { Signature::uniform(0, vec![], Volatility::Immutable) } - fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let db = func_ctx.query_ctx.current_schema(); Ok(Arc::new(StringVector::from_slice(&[&db])) as _) @@ -87,7 +87,7 @@ impl Function for SessionUserFunction { Signature::uniform(0, vec![], Volatility::Immutable) } - fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let user = func_ctx.query_ctx.current_user(); Ok(Arc::new(StringVector::from_slice(&[user.username()])) as _) @@ -138,7 +138,7 @@ mod tests { query_ctx, ..Default::default() }; - let vector = build.eval(func_ctx, &[]).unwrap(); + let vector = build.eval(&func_ctx, &[]).unwrap(); let expect: VectorRef = Arc::new(StringVector::from(vec!["test_db"])); assert_eq!(expect, vector); } diff --git a/src/common/function/src/system/pg_catalog/pg_get_userbyid.rs b/src/common/function/src/system/pg_catalog/pg_get_userbyid.rs index d618ec4ecd..1b0b1a987d 100644 --- a/src/common/function/src/system/pg_catalog/pg_get_userbyid.rs +++ b/src/common/function/src/system/pg_catalog/pg_get_userbyid.rs @@ -53,7 +53,7 @@ impl Function for PGGetUserByIdFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$T| { let col = scalar_unary_op::<<$T as LogicalPrimitiveType>::Native, String, _>(&columns[0], pg_get_user_by_id, &mut EvalContext::default())?; Ok(Arc::new(col)) diff --git a/src/common/function/src/system/pg_catalog/table_is_visible.rs b/src/common/function/src/system/pg_catalog/table_is_visible.rs index 630ad13762..eef007cf04 100644 --- a/src/common/function/src/system/pg_catalog/table_is_visible.rs +++ b/src/common/function/src/system/pg_catalog/table_is_visible.rs @@ -53,7 +53,7 @@ impl Function for PGTableIsVisibleFunction { ) } - fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$T| { let col = scalar_unary_op::<<$T as LogicalPrimitiveType>::Native, bool, _>(&columns[0], pg_table_is_visible, &mut EvalContext::default())?; Ok(Arc::new(col)) diff --git a/src/common/function/src/system/pg_catalog/version.rs b/src/common/function/src/system/pg_catalog/version.rs index e9511bd6e1..7be27fe9b2 100644 --- a/src/common/function/src/system/pg_catalog/version.rs +++ b/src/common/function/src/system/pg_catalog/version.rs @@ -44,7 +44,7 @@ impl Function for PGVersionFunction { Signature::exact(vec![], Volatility::Immutable) } - fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let result = StringVector::from(vec![format!( "PostgreSQL 16.3 GreptimeDB {}", env!("CARGO_PKG_VERSION") diff --git a/src/common/function/src/system/timezone.rs b/src/common/function/src/system/timezone.rs index 1ac873e61b..3c1a7254aa 100644 --- a/src/common/function/src/system/timezone.rs +++ b/src/common/function/src/system/timezone.rs @@ -41,7 +41,7 @@ impl Function for TimezoneFunction { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let tz = func_ctx.query_ctx.timezone().to_string(); Ok(Arc::new(StringVector::from_slice(&[&tz])) as _) @@ -77,7 +77,7 @@ mod tests { query_ctx, ..Default::default() }; - let vector = build.eval(func_ctx, &[]).unwrap(); + let vector = build.eval(&func_ctx, &[]).unwrap(); let expect: VectorRef = Arc::new(StringVector::from(vec!["UTC"])); assert_eq!(expect, vector); } diff --git a/src/common/function/src/system/version.rs b/src/common/function/src/system/version.rs index 96a8d7fc6b..bfab3f1334 100644 --- a/src/common/function/src/system/version.rs +++ b/src/common/function/src/system/version.rs @@ -45,7 +45,7 @@ impl Function for VersionFunction { Signature::nullary(Volatility::Immutable) } - fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let version = match func_ctx.query_ctx.channel() { Channel::Mysql => { format!( diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index 4141f34881..b81d4cde8b 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -30,14 +30,6 @@ use statrs::StatsError; #[snafu(visibility(pub))] #[stack_trace_debug] pub enum Error { - #[snafu(display("Failed to execute function"))] - ExecuteFunction { - #[snafu(source)] - error: DataFusionError, - #[snafu(implicit)] - location: Location, - }, - #[snafu(display("Unsupported input datatypes {:?} in function {}", datatypes, function))] UnsupportedInputDataType { function: String, @@ -264,9 +256,7 @@ impl ErrorExt for Error { | Error::ArrowCompute { .. } | Error::FlownodeNotFound { .. } => StatusCode::EngineExecuteQuery, - Error::ExecuteFunction { error, .. } | Error::GeneralDataFusion { error, .. } => { - datafusion_status_code::(error, None) - } + Error::GeneralDataFusion { error, .. } => datafusion_status_code::(error, None), Error::InvalidInputType { source, .. } | Error::IntoVector { source, .. } diff --git a/src/common/query/src/function.rs b/src/common/query/src/function.rs index 6eb683c797..e1806737a6 100644 --- a/src/common/query/src/function.rs +++ b/src/common/query/src/function.rs @@ -17,23 +17,9 @@ use std::sync::Arc; use datafusion_expr::ReturnTypeFunction as DfReturnTypeFunction; use datatypes::arrow::datatypes::DataType as ArrowDataType; use datatypes::prelude::{ConcreteDataType, DataType}; -use datatypes::vectors::VectorRef; -use snafu::ResultExt; -use crate::error::{ExecuteFunctionSnafu, Result}; +use crate::error::Result; use crate::logical_plan::Accumulator; -use crate::prelude::{ColumnarValue, ScalarValue}; - -/// Scalar function -/// -/// The Fn param is the wrapped function but be aware that the function will -/// be passed with the slice / vec of columnar values (either scalar or array) -/// with the exception of zero param function, where a singular element vec -/// will be passed. In that case the single element is a null array to indicate -/// the batch's row count (so that the generative zero-argument function can know -/// the result array size). -pub type ScalarFunctionImplementation = - Arc Result + Send + Sync>; /// A function's return type pub type ReturnTypeFunction = @@ -51,48 +37,6 @@ pub type AccumulatorCreatorFunction = pub type StateTypeFunction = Arc Result>> + Send + Sync>; -/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function -/// and vice-versa after evaluation. -pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation -where - F: Fn(&[VectorRef]) -> Result + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an vector. If yes, store its `len`, - // as any scalar will need to be converted to an vector of len `len`. - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Vector(v) => Some(v.len()), - }); - - // to array - // TODO(dennis): we create new vectors from Scalar on each call, - // should be optimized in the future. - let args: Result> = if let Some(len) = len { - args.iter() - .map(|arg| arg.clone().try_into_vector(len)) - .collect() - } else { - args.iter() - .map(|arg| arg.clone().try_into_vector(1)) - .collect() - }; - - let result = (inner)(&args?); - - // maybe back to scalar - if len.is_some() { - result.map(ColumnarValue::Vector) - } else { - Ok(ScalarValue::try_from_array(&result?.to_arrow_array(), 0) - .map(ColumnarValue::Scalar) - .context(ExecuteFunctionSnafu)?) - } - }) -} - pub fn to_df_return_type(func: ReturnTypeFunction) -> DfReturnTypeFunction { let df_func = move |data_types: &[ArrowDataType]| { // DataFusion DataType -> ConcreteDataType @@ -111,60 +55,3 @@ pub fn to_df_return_type(func: ReturnTypeFunction) -> DfReturnTypeFunction { }; Arc::new(df_func) } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use datatypes::prelude::{ScalarVector, Vector}; - use datatypes::vectors::BooleanVector; - - use super::*; - - #[test] - fn test_make_scalar_function() { - let and_fun = |args: &[VectorRef]| -> Result { - let left = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let right = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - let result = left - .iter_data() - .zip(right.iter_data()) - .map(|(left, right)| match (left, right) { - (Some(left), Some(right)) => Some(left && right), - _ => None, - }) - .collect::(); - Ok(Arc::new(result) as VectorRef) - }; - - let and_fun = make_scalar_function(and_fun); - - let args = vec![ - ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))), - ColumnarValue::Vector(Arc::new(BooleanVector::from(vec![ - true, false, false, true, - ]))), - ]; - - let vec = (and_fun)(&args).unwrap(); - - match vec { - ColumnarValue::Vector(vec) => { - let vec = vec.as_any().downcast_ref::().unwrap(); - - assert_eq!(4, vec.len()); - for i in 0..4 { - assert_eq!(i == 0 || i == 3, vec.get_data(i).unwrap(), "Failed at {i}") - } - } - _ => unreachable!(), - } - } -} diff --git a/src/common/query/src/logical_plan.rs b/src/common/query/src/logical_plan.rs index 6dc94307b8..974a30a15a 100644 --- a/src/common/query/src/logical_plan.rs +++ b/src/common/query/src/logical_plan.rs @@ -15,7 +15,6 @@ pub mod accumulator; mod expr; mod udaf; -mod udf; use std::sync::Arc; @@ -24,38 +23,14 @@ use datafusion::error::Result as DatafusionResult; use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; use datafusion_common::Column; use datafusion_expr::col; -use datatypes::prelude::ConcreteDataType; pub use expr::{build_filter_from_timestamp, build_same_type_ts_filter}; pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef}; pub use self::udaf::AggregateFunction; -pub use self::udf::ScalarUdf; use crate::error::Result; -use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation}; use crate::logical_plan::accumulator::*; use crate::signature::{Signature, Volatility}; -/// Creates a new UDF with a specific signature and specific return type. -/// This is a helper function to create a new UDF. -/// The function `create_udf` returns a subset of all possible `ScalarFunction`: -/// * the UDF has a fixed return type -/// * the UDF has a fixed signature (e.g. [f64, f64]) -pub fn create_udf( - name: &str, - input_types: Vec, - return_type: Arc, - volatility: Volatility, - fun: ScalarFunctionImplementation, -) -> ScalarUdf { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - ScalarUdf::new( - name, - &Signature::exact(input_types, volatility), - &return_type, - &fun, - ) -} - pub fn create_aggregate_function( name: String, args_count: u8, @@ -127,102 +102,17 @@ pub type SubstraitPlanDecoderRef = Arc; mod tests { use std::sync::Arc; - use datafusion_common::DFSchema; use datafusion_expr::builder::LogicalTableSource; - use datafusion_expr::{ - lit, ColumnarValue as DfColumnarValue, ScalarUDF as DfScalarUDF, - TypeSignature as DfTypeSignature, - }; - use datatypes::arrow::array::BooleanArray; + use datafusion_expr::lit; use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datatypes::prelude::*; - use datatypes::vectors::{BooleanVector, VectorRef}; + use datatypes::vectors::VectorRef; use super::*; use crate::error::Result; - use crate::function::{make_scalar_function, AccumulatorCreatorFunction}; - use crate::prelude::ScalarValue; + use crate::function::AccumulatorCreatorFunction; use crate::signature::TypeSignature; - #[test] - fn test_create_udf() { - let and_fun = |args: &[VectorRef]| -> Result { - let left = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let right = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - let result = left - .iter_data() - .zip(right.iter_data()) - .map(|(left, right)| match (left, right) { - (Some(left), Some(right)) => Some(left && right), - _ => None, - }) - .collect::(); - Ok(Arc::new(result) as VectorRef) - }; - - let and_fun = make_scalar_function(and_fun); - - let input_types = vec![ - ConcreteDataType::boolean_datatype(), - ConcreteDataType::boolean_datatype(), - ]; - - let return_type = Arc::new(ConcreteDataType::boolean_datatype()); - - let udf = create_udf( - "and", - input_types.clone(), - return_type.clone(), - Volatility::Immutable, - and_fun.clone(), - ); - - assert_eq!("and", udf.name); - assert!( - matches!(&udf.signature.type_signature, TypeSignature::Exact(ts) if ts.clone() == input_types) - ); - assert_eq!(return_type, (udf.return_type)(&[]).unwrap()); - - // test into_df_udf - let df_udf: DfScalarUDF = udf.into(); - assert_eq!("and", df_udf.name()); - - let types = vec![DataType::Boolean, DataType::Boolean]; - assert!( - matches!(&df_udf.signature().type_signature, DfTypeSignature::Exact(ts) if ts.clone() == types) - ); - assert_eq!( - DataType::Boolean, - df_udf - .return_type_from_exprs(&[], &DFSchema::empty(), &[]) - .unwrap() - ); - - let args = vec![ - DfColumnarValue::Scalar(ScalarValue::Boolean(Some(true))), - DfColumnarValue::Array(Arc::new(BooleanArray::from(vec![true, false, false, true]))), - ]; - - // call the function - let result = df_udf.invoke_batch(&args, 4).unwrap(); - match result { - DfColumnarValue::Array(arr) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - for i in 0..4 { - assert_eq!(i == 0 || i == 3, arr.value(i)); - } - } - _ => unreachable!(), - } - } - #[derive(Debug)] struct DummyAccumulator; diff --git a/src/common/query/src/logical_plan/udf.rs b/src/common/query/src/logical_plan/udf.rs deleted file mode 100644 index 276f753e77..0000000000 --- a/src/common/query/src/logical_plan/udf.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Udf module contains foundational types that are used to represent UDFs. -//! It's modified from datafusion. -use std::any::Any; -use std::fmt; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; - -use datafusion_expr::{ - ColumnarValue as DfColumnarValue, - ScalarFunctionImplementation as DfScalarFunctionImplementation, ScalarUDF as DfScalarUDF, - ScalarUDFImpl, -}; -use datatypes::arrow::datatypes::DataType; - -use crate::error::Result; -use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation}; -use crate::prelude::to_df_return_type; -use crate::signature::Signature; - -/// Logical representation of a UDF. -#[derive(Clone)] -pub struct ScalarUdf { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - pub fun: ScalarFunctionImplementation, -} - -impl Debug for ScalarUdf { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUdf") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl ScalarUdf { - /// Create a new ScalarUdf - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - fun: &ScalarFunctionImplementation, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), - } - } -} - -#[derive(Clone)] -struct DfUdfAdapter { - name: String, - signature: datafusion_expr::Signature, - return_type: datafusion_expr::ReturnTypeFunction, - fun: DfScalarFunctionImplementation, -} - -impl Debug for DfUdfAdapter { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("DfUdfAdapter") - .field("name", &self.name) - .field("signature", &self.signature) - .finish() - } -} - -impl ScalarUDFImpl for DfUdfAdapter { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - &self.name - } - - fn signature(&self) -> &datafusion_expr::Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { - (self.return_type)(arg_types).map(|ty| ty.as_ref().clone()) - } - - fn invoke(&self, args: &[DfColumnarValue]) -> datafusion_common::Result { - (self.fun)(args) - } - - fn invoke_no_args(&self, number_rows: usize) -> datafusion_common::Result { - Ok((self.fun)(&[])?.into_array(number_rows)?.into()) - } -} - -impl From for DfScalarUDF { - fn from(udf: ScalarUdf) -> Self { - DfScalarUDF::new_from_impl(DfUdfAdapter { - name: udf.name, - signature: udf.signature.into(), - return_type: to_df_return_type(udf.return_type), - fun: to_df_scalar_func(udf.fun), - }) - } -} - -fn to_df_scalar_func(fun: ScalarFunctionImplementation) -> DfScalarFunctionImplementation { - Arc::new(move |args: &[DfColumnarValue]| { - let args: Result> = args.iter().map(TryFrom::try_from).collect(); - let result = fun(&args?); - result.map(From::from).map_err(|e| e.into()) - }) -} diff --git a/src/common/query/src/prelude.rs b/src/common/query/src/prelude.rs index 8cfc125583..83c6f8564d 100644 --- a/src/common/query/src/prelude.rs +++ b/src/common/query/src/prelude.rs @@ -16,7 +16,7 @@ pub use datafusion_common::ScalarValue; pub use crate::columnar_value::ColumnarValue; pub use crate::function::*; -pub use crate::logical_plan::{create_udf, AggregateFunction, ScalarUdf}; +pub use crate::logical_plan::AggregateFunction; pub use crate::signature::{Signature, TypeSignature, Volatility}; /// Default timestamp column name for Prometheus metrics. diff --git a/src/datanode/src/tests.rs b/src/datanode/src/tests.rs index a7f95e29fd..8e8878fa79 100644 --- a/src/datanode/src/tests.rs +++ b/src/datanode/src/tests.rs @@ -21,7 +21,6 @@ use async_trait::async_trait; use common_error::ext::BoxedError; use common_function::function::FunctionRef; use common_function::scalars::aggregate::AggregateFunctionMetaRef; -use common_query::prelude::ScalarUdf; use common_query::Output; use common_runtime::runtime::{BuilderBuild, RuntimeTrait}; use common_runtime::Runtime; @@ -77,8 +76,6 @@ impl QueryEngine for MockQueryEngine { unimplemented!() } - fn register_udf(&self, _udf: ScalarUdf) {} - fn register_aggregate_function(&self, _func: AggregateFunctionMetaRef) {} fn register_function(&self, _func: FunctionRef) {} diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index 01ff9cc299..15da89b21f 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -17,6 +17,7 @@ use std::collections::BTreeMap; use std::sync::Arc; use common_error::ext::BoxedError; +use common_function::function::FunctionContext; use datafusion_substrait::extensions::Extensions; use datatypes::data_type::ConcreteDataType as CDT; use query::QueryEngine; @@ -146,7 +147,7 @@ impl common_function::function::Function for TumbleFunction { fn eval( &self, - _func_ctx: common_function::function::FunctionContext, + _func_ctx: &FunctionContext, _columns: &[datatypes::prelude::VectorRef], ) -> common_query::error::Result { UnexpectedSnafu { diff --git a/src/frontend/src/instance/jaeger.rs b/src/frontend/src/instance/jaeger.rs index 8e7c350386..cfc89a8ed9 100644 --- a/src/frontend/src/instance/jaeger.rs +++ b/src/frontend/src/instance/jaeger.rs @@ -262,8 +262,11 @@ fn create_df_context( ]; for udf in udfs { - df_context - .register_udf(create_udf(udf, ctx.clone(), Arc::new(FunctionState::default())).into()); + df_context.register_udf(create_udf( + udf, + ctx.clone(), + Arc::new(FunctionState::default()), + )); } Ok(df_context) diff --git a/src/mito2/src/sst/index/fulltext_index/applier/builder.rs b/src/mito2/src/sst/index/fulltext_index/applier/builder.rs index b76bdc2f1b..5e91bbfac4 100644 --- a/src/mito2/src/sst/index/fulltext_index/applier/builder.rs +++ b/src/mito2/src/sst/index/fulltext_index/applier/builder.rs @@ -159,14 +159,11 @@ mod tests { } fn matches_func() -> Arc { - Arc::new( - create_udf( - FUNCTION_REGISTRY.get_function("matches").unwrap(), - QueryContext::arc(), - Default::default(), - ) - .into(), - ) + Arc::new(create_udf( + FUNCTION_REGISTRY.get_function("matches").unwrap(), + QueryContext::arc(), + Default::default(), + )) } #[test] diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index fff002268a..036d717884 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -27,7 +27,6 @@ use common_catalog::consts::is_readonly_schema; use common_error::ext::BoxedError; use common_function::function::FunctionRef; use common_function::scalars::aggregate::AggregateFunctionMetaRef; -use common_query::prelude::ScalarUdf; use common_query::{Output, OutputData, OutputMeta}; use common_recordbatch::adapter::RecordBatchStreamAdapter; use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; @@ -455,11 +454,6 @@ impl QueryEngine for DatafusionQueryEngine { self.state.register_aggregate_function(func); } - /// Register a [`ScalarUdf`]. - fn register_udf(&self, udf: ScalarUdf) { - self.state.register_udf(udf); - } - /// Register an UDF function. /// Will override if the function with same name is already registered. fn register_function(&self, func: FunctionRef) { diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 25f1015735..13e95ee560 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -155,14 +155,11 @@ impl ContextProvider for DfContextProviderAdapter { self.engine_state.udf_function(name).map_or_else( || self.session_state.scalar_functions().get(name).cloned(), |func| { - Some(Arc::new( - create_udf( - func, - self.query_ctx.clone(), - self.engine_state.function_state(), - ) - .into(), - )) + Some(Arc::new(create_udf( + func, + self.query_ctx.clone(), + self.engine_state.function_state(), + ))) }, ) } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 61f6e1a8f0..c4e8aee7d1 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -28,7 +28,6 @@ use common_function::handlers::{ FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef, }; use common_function::scalars::aggregate::AggregateFunctionMetaRef; -use common_query::prelude::ScalarUdf; use common_query::Output; use datafusion_expr::LogicalPlan; use datatypes::schema::Schema; @@ -75,9 +74,6 @@ pub trait QueryEngine: Send + Sync { /// Execute the given [`LogicalPlan`]. async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result; - /// Register a [`ScalarUdf`]. - fn register_udf(&self, udf: ScalarUdf); - /// Register an aggregate function. /// /// # Panics diff --git a/src/query/src/query_engine/default_serializer.rs b/src/query/src/query_engine/default_serializer.rs index d35feeb1a2..63ae3ab4fa 100644 --- a/src/query/src/query_engine/default_serializer.rs +++ b/src/query/src/query_engine/default_serializer.rs @@ -27,7 +27,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; use datafusion::logical_expr::LogicalPlan; -use datafusion_expr::{ScalarUDF, UserDefinedLogicalNode}; +use datafusion_expr::UserDefinedLogicalNode; use greptime_proto::substrait_extension::MergeScan as PbMergeScan; use prost::Message; use session::context::QueryContextRef; @@ -120,9 +120,11 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { // e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()` // before we build the session state, the UDF will be lost. for func in FUNCTION_REGISTRY.functions() { - let udf: Arc = Arc::new( - create_udf(func.clone(), self.query_ctx.clone(), Default::default()).into(), - ); + let udf = Arc::new(create_udf( + func.clone(), + self.query_ctx.clone(), + Default::default(), + )); session_state .register_udf(udf) .context(RegisterUdfSnafu { name: func.name() })?; diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 0eb31f31b5..ab63c6491a 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -25,7 +25,6 @@ use common_function::handlers::{ }; use common_function::scalars::aggregate::AggregateFunctionMetaRef; use common_function::state::FunctionState; -use common_query::prelude::ScalarUdf; use common_telemetry::warn; use datafusion::dataframe::DataFrame; use datafusion::error::Result as DfResult; @@ -242,11 +241,6 @@ impl QueryEngineState { .collect() } - /// Register a [`ScalarUdf`]. - pub fn register_udf(&self, udf: ScalarUdf) { - self.df_context.register_udf(udf.into()); - } - /// Register an aggregate function. /// /// # Panics diff --git a/src/query/src/tests.rs b/src/query/src/tests.rs index 34f2ecbdba..cbce67a4fe 100644 --- a/src/query/src/tests.rs +++ b/src/query/src/tests.rs @@ -32,7 +32,6 @@ mod scipy_stats_norm_pdf; mod time_range_filter_test; mod function; -mod pow; mod vec_product_test; mod vec_sum_test; diff --git a/src/query/src/tests/pow.rs b/src/query/src/tests/pow.rs deleted file mode 100644 index ffb0e85e02..0000000000 --- a/src/query/src/tests/pow.rs +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use common_query::error::Result; -use datatypes::prelude::{ScalarVector, Vector}; -use datatypes::vectors::{UInt32Vector, VectorRef}; - -pub fn pow(args: &[VectorRef]) -> Result { - assert_eq!(args.len(), 2); - - let base = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let exponent = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - assert_eq!(exponent.len(), base.len()); - - let iter = base - .iter_data() - .zip(exponent.iter_data()) - .map(|(base, exponent)| { - match (base, exponent) { - // in arrow, any value can be null. - // Here we decide to make our UDF to return null when either base or exponent is null. - (Some(base), Some(exponent)) => Some(base.pow(exponent)), - _ => None, - } - }); - let v = UInt32Vector::from_owned_iterator(iter); - - Ok(Arc::new(v) as _) -} diff --git a/src/query/src/tests/query_engine_test.rs b/src/query/src/tests/query_engine_test.rs index d46d7afd9d..0f3f817703 100644 --- a/src/query/src/tests/query_engine_test.rs +++ b/src/query/src/tests/query_engine_test.rs @@ -19,7 +19,6 @@ use catalog::RegisterTableRequest; use common_base::Plugins; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID}; use common_error::ext::BoxedError; -use common_query::prelude::{create_udf, make_scalar_function, Volatility}; use common_query::OutputData; use common_recordbatch::{util, RecordBatch}; use datafusion::datasource::DefaultTableSource; @@ -37,8 +36,6 @@ use crate::error::{QueryExecutionSnafu, Result}; use crate::parser::QueryLanguageParser; use crate::query_engine::options::QueryOptions; use crate::query_engine::QueryEngineFactory; -use crate::tests::exec_selection; -use crate::tests::pow::pow; #[tokio::test] async fn test_datafusion_query_engine() -> Result<()> { @@ -150,46 +147,3 @@ async fn test_query_validate() -> Result<()> { .is_err()); Ok(()) } - -#[tokio::test] -async fn test_udf() -> Result<()> { - common_telemetry::init_default_ut_logging(); - let catalog_list = catalog_manager()?; - - let factory = QueryEngineFactory::new(catalog_list, None, None, None, None, false); - let engine = factory.query_engine(); - - let pow = make_scalar_function(pow); - - let udf = create_udf( - // datafusion already supports pow, so we use a different name. - "my_pow", - vec![ - ConcreteDataType::uint32_datatype(), - ConcreteDataType::uint32_datatype(), - ], - Arc::new(ConcreteDataType::uint32_datatype()), - Volatility::Immutable, - pow, - ); - - engine.register_udf(udf); - - let sql = "select my_pow(number, number) as p from numbers limit 10"; - let numbers = exec_selection(engine, sql).await; - assert_eq!(1, numbers.len()); - assert_eq!(numbers[0].num_columns(), 1); - assert_eq!(1, numbers[0].schema.num_columns()); - assert_eq!("p", numbers[0].schema.column_schemas()[0].name); - - let batch = &numbers[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(batch.column(0).len(), 10); - let expected: Vec = vec![1, 1, 4, 27, 256, 3125, 46656, 823543, 16777216, 387420489]; - assert_eq!( - *batch.column(0), - Arc::new(UInt32Vector::from_slice(expected)) as VectorRef - ); - - Ok(()) -}