From c26138963e9f0b5b050b5c191a3236aef244f414 Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Tue, 10 Jun 2025 18:11:06 +0800 Subject: [PATCH] refactor: unify function registry (Part 1) (#6262) * refactor: unify function registry (Part 1) Signed-off-by: Zhenchi * refactor: simplify via register_scalar Signed-off-by: Zhenchi --------- Signed-off-by: Zhenchi --- src/common/function/src/{aggr.rs => aggrs.rs} | 12 +-- src/common/function/src/aggrs/approximate.rs | 32 +++++++ .../src/{aggr => aggrs/approximate}/hll.rs | 0 .../approximate/uddsketch.rs} | 0 src/common/function/src/aggrs/geo.rs | 27 ++++++ .../src/{scalars => aggrs}/geo/encoding.rs | 15 +++- .../src/{aggr => aggrs/geo}/geo_path.rs | 2 +- src/common/function/src/aggrs/vector.rs | 29 ++++++ .../src/{scalars => aggrs}/vector/product.rs | 15 +++- .../src/{scalars => aggrs}/vector/sum.rs | 16 +++- src/common/function/src/function_factory.rs | 63 +++++++++++++ src/common/function/src/function_registry.rs | 65 ++++++++------ src/common/function/src/lib.rs | 5 +- src/common/function/src/scalars.rs | 1 - src/common/function/src/scalars/aggregate.rs | 89 ------------------- src/common/function/src/scalars/date.rs | 7 +- src/common/function/src/scalars/expression.rs | 4 +- src/common/function/src/scalars/geo.rs | 72 ++++++++------- .../function/src/scalars/geo/helpers.rs | 6 +- src/common/function/src/scalars/hll_count.rs | 7 +- src/common/function/src/scalars/ip.rs | 18 ++-- src/common/function/src/scalars/json.rs | 31 ++++--- src/common/function/src/scalars/matches.rs | 4 +- .../function/src/scalars/matches_term.rs | 2 +- src/common/function/src/scalars/math.rs | 15 ++-- src/common/function/src/scalars/timestamp.rs | 3 +- .../function/src/scalars/uddsketch_calc.rs | 3 +- src/common/function/src/scalars/vector.rs | 39 ++++---- src/common/function/src/system.rs | 14 +-- src/common/function/src/system/pg_catalog.rs | 8 +- src/datanode/src/tests.rs | 9 +- src/flow/src/transform.rs | 12 ++- .../index/fulltext_index/applier/builder.rs | 25 +++--- src/query/src/datafusion.rs | 17 ++-- src/query/src/datafusion/planner.rs | 42 +++------ src/query/src/query_engine.rs | 15 ++-- .../src/query_engine/default_serializer.rs | 74 +++++++-------- src/query/src/query_engine/state.rs | 74 +++++++-------- src/query/src/tests/my_sum_udaf_example.rs | 18 ++-- 39 files changed, 484 insertions(+), 406 deletions(-) rename src/common/function/src/{aggr.rs => aggrs.rs} (68%) create mode 100644 src/common/function/src/aggrs/approximate.rs rename src/common/function/src/{aggr => aggrs/approximate}/hll.rs (100%) rename src/common/function/src/{aggr/uddsketch_state.rs => aggrs/approximate/uddsketch.rs} (100%) create mode 100644 src/common/function/src/aggrs/geo.rs rename src/common/function/src/{scalars => aggrs}/geo/encoding.rs (94%) rename src/common/function/src/{aggr => aggrs/geo}/geo_path.rs (99%) create mode 100644 src/common/function/src/aggrs/vector.rs rename src/common/function/src/{scalars => aggrs}/vector/product.rs (94%) rename src/common/function/src/{scalars => aggrs}/vector/sum.rs (93%) create mode 100644 src/common/function/src/function_factory.rs delete mode 100644 src/common/function/src/scalars/aggregate.rs diff --git a/src/common/function/src/aggr.rs b/src/common/function/src/aggrs.rs similarity index 68% rename from src/common/function/src/aggr.rs rename to src/common/function/src/aggrs.rs index 8b4486906d..6bce2ee657 100644 --- a/src/common/function/src/aggr.rs +++ b/src/common/function/src/aggrs.rs @@ -12,11 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod geo_path; -mod hll; -mod uddsketch_state; - -pub use geo_path::{GeoPathAccumulator, GEO_PATH_NAME}; -pub(crate) use hll::HllStateType; -pub use hll::{HllState, HLL_MERGE_NAME, HLL_NAME}; -pub use uddsketch_state::{UddSketchState, UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME}; +pub mod approximate; +#[cfg(feature = "geo")] +pub mod geo; +pub mod vector; diff --git a/src/common/function/src/aggrs/approximate.rs b/src/common/function/src/aggrs/approximate.rs new file mode 100644 index 0000000000..25044da27a --- /dev/null +++ b/src/common/function/src/aggrs/approximate.rs @@ -0,0 +1,32 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::function_registry::FunctionRegistry; + +pub(crate) mod hll; +mod uddsketch; + +pub(crate) struct ApproximateFunction; + +impl ApproximateFunction { + pub fn register(registry: &FunctionRegistry) { + // uddsketch + registry.register_aggr(uddsketch::UddSketchState::state_udf_impl()); + registry.register_aggr(uddsketch::UddSketchState::merge_udf_impl()); + + // hll + registry.register_aggr(hll::HllState::state_udf_impl()); + registry.register_aggr(hll::HllState::merge_udf_impl()); + } +} diff --git a/src/common/function/src/aggr/hll.rs b/src/common/function/src/aggrs/approximate/hll.rs similarity index 100% rename from src/common/function/src/aggr/hll.rs rename to src/common/function/src/aggrs/approximate/hll.rs diff --git a/src/common/function/src/aggr/uddsketch_state.rs b/src/common/function/src/aggrs/approximate/uddsketch.rs similarity index 100% rename from src/common/function/src/aggr/uddsketch_state.rs rename to src/common/function/src/aggrs/approximate/uddsketch.rs diff --git a/src/common/function/src/aggrs/geo.rs b/src/common/function/src/aggrs/geo.rs new file mode 100644 index 0000000000..5caa43d263 --- /dev/null +++ b/src/common/function/src/aggrs/geo.rs @@ -0,0 +1,27 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::function_registry::FunctionRegistry; + +mod encoding; +mod geo_path; + +pub(crate) struct GeoFunction; + +impl GeoFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_aggr(geo_path::GeoPathAccumulator::uadf_impl()); + registry.register_aggr(encoding::JsonPathAccumulator::uadf_impl()); + } +} diff --git a/src/common/function/src/scalars/geo/encoding.rs b/src/common/function/src/aggrs/geo/encoding.rs similarity index 94% rename from src/common/function/src/scalars/geo/encoding.rs rename to src/common/function/src/aggrs/geo/encoding.rs index 10a2df97be..b21b6c0dd1 100644 --- a/src/common/function/src/scalars/geo/encoding.rs +++ b/src/common/function/src/aggrs/geo/encoding.rs @@ -19,9 +19,12 @@ use common_error::status_code::StatusCode; use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{self, InvalidInputStateSnafu, Result}; use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; +use common_query::logical_plan::{ + create_aggregate_function, Accumulator, AggregateFunctionCreator, +}; use common_query::prelude::AccumulatorCreatorFunction; use common_time::Timestamp; +use datafusion_expr::AggregateUDF; use datatypes::prelude::ConcreteDataType; use datatypes::value::{ListValue, Value}; use datatypes::vectors::VectorRef; @@ -47,6 +50,16 @@ impl JsonPathAccumulator { timestamp_type, } } + + /// Create a new `AggregateUDF` for the `json_encode_path` aggregate function. + pub fn uadf_impl() -> AggregateUDF { + create_aggregate_function( + "json_encode_path".to_string(), + 3, + Arc::new(JsonPathEncodeFunctionCreator::default()), + ) + .into() + } } impl Accumulator for JsonPathAccumulator { diff --git a/src/common/function/src/aggr/geo_path.rs b/src/common/function/src/aggrs/geo/geo_path.rs similarity index 99% rename from src/common/function/src/aggr/geo_path.rs rename to src/common/function/src/aggrs/geo/geo_path.rs index d5a2f71b57..08abe0c731 100644 --- a/src/common/function/src/aggr/geo_path.rs +++ b/src/common/function/src/aggrs/geo/geo_path.rs @@ -47,7 +47,7 @@ impl GeoPathAccumulator { Self::default() } - pub fn udf_impl() -> AggregateUDF { + pub fn uadf_impl() -> AggregateUDF { create_udaf( GEO_PATH_NAME, // Input types: lat, lng, timestamp diff --git a/src/common/function/src/aggrs/vector.rs b/src/common/function/src/aggrs/vector.rs new file mode 100644 index 0000000000..5af064d002 --- /dev/null +++ b/src/common/function/src/aggrs/vector.rs @@ -0,0 +1,29 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::aggrs::vector::product::VectorProduct; +use crate::aggrs::vector::sum::VectorSum; +use crate::function_registry::FunctionRegistry; + +mod product; +mod sum; + +pub(crate) struct VectorFunction; + +impl VectorFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register_aggr(VectorSum::uadf_impl()); + registry.register_aggr(VectorProduct::uadf_impl()); + } +} diff --git a/src/common/function/src/scalars/vector/product.rs b/src/common/function/src/aggrs/vector/product.rs similarity index 94% rename from src/common/function/src/scalars/vector/product.rs rename to src/common/function/src/aggrs/vector/product.rs index fb1475ff14..8e7e62feee 100644 --- a/src/common/function/src/scalars/vector/product.rs +++ b/src/common/function/src/aggrs/vector/product.rs @@ -16,8 +16,11 @@ use std::sync::Arc; use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu}; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; +use common_query::logical_plan::{ + create_aggregate_function, Accumulator, AggregateFunctionCreator, +}; use common_query::prelude::AccumulatorCreatorFunction; +use datafusion_expr::AggregateUDF; use datatypes::prelude::{ConcreteDataType, Value, *}; use datatypes::vectors::VectorRef; use nalgebra::{Const, DVectorView, Dyn, OVector}; @@ -75,6 +78,16 @@ impl AggregateFunctionCreator for VectorProductCreator { } impl VectorProduct { + /// Create a new `AggregateUDF` for the `vec_product` aggregate function. + pub fn uadf_impl() -> AggregateUDF { + create_aggregate_function( + "vec_product".to_string(), + 1, + Arc::new(VectorProductCreator::default()), + ) + .into() + } + fn inner(&mut self, len: usize) -> &mut OVector { self.product.get_or_insert_with(|| { OVector::from_iterator_generic(Dyn(len), Const::<1>, (0..len).map(|_| 1.0)) diff --git a/src/common/function/src/scalars/vector/sum.rs b/src/common/function/src/aggrs/vector/sum.rs similarity index 93% rename from src/common/function/src/scalars/vector/sum.rs rename to src/common/function/src/aggrs/vector/sum.rs index c293abbeb4..920b5ae289 100644 --- a/src/common/function/src/scalars/vector/sum.rs +++ b/src/common/function/src/aggrs/vector/sum.rs @@ -16,8 +16,11 @@ use std::sync::Arc; use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu}; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; +use common_query::logical_plan::{ + create_aggregate_function, Accumulator, AggregateFunctionCreator, +}; use common_query::prelude::AccumulatorCreatorFunction; +use datafusion_expr::AggregateUDF; use datatypes::prelude::{ConcreteDataType, Value, *}; use datatypes::vectors::VectorRef; use nalgebra::{Const, DVectorView, Dyn, OVector}; @@ -25,6 +28,7 @@ use snafu::ensure; use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; +/// The accumulator for the `vec_sum` aggregate function. #[derive(Debug, Default)] pub struct VectorSum { sum: Option>, @@ -74,6 +78,16 @@ impl AggregateFunctionCreator for VectorSumCreator { } impl VectorSum { + /// Create a new `AggregateUDF` for the `vec_sum` aggregate function. + pub fn uadf_impl() -> AggregateUDF { + create_aggregate_function( + "vec_sum".to_string(), + 1, + Arc::new(VectorSumCreator::default()), + ) + .into() + } + fn inner(&mut self, len: usize) -> &mut OVector { self.sum .get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>)) diff --git a/src/common/function/src/function_factory.rs b/src/common/function/src/function_factory.rs new file mode 100644 index 0000000000..045692f187 --- /dev/null +++ b/src/common/function/src/function_factory.rs @@ -0,0 +1,63 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use datafusion_expr::ScalarUDF; + +use crate::function::{FunctionContext, FunctionRef}; +use crate::scalars::udf::create_udf; + +/// A factory for creating `ScalarUDF` that require a function context. +#[derive(Clone)] +pub struct ScalarFunctionFactory { + name: String, + factory: Arc ScalarUDF + Send + Sync>, +} + +impl ScalarFunctionFactory { + /// Returns the name of the function. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns a `ScalarUDF` when given a function context. + pub fn provide(&self, ctx: FunctionContext) -> ScalarUDF { + (self.factory)(ctx) + } +} + +impl From for ScalarFunctionFactory { + fn from(df_udf: ScalarUDF) -> Self { + let name = df_udf.name().to_string(); + let func = Arc::new(move |_ctx| df_udf.clone()); + Self { + name, + factory: func, + } + } +} + +impl From for ScalarFunctionFactory { + fn from(func: FunctionRef) -> Self { + let name = func.name().to_string(); + let func = Arc::new(move |ctx: FunctionContext| { + create_udf(func.clone(), ctx.query_ctx, ctx.state) + }); + Self { + name, + factory: func, + } + } +} diff --git a/src/common/function/src/function_registry.rs b/src/common/function/src/function_registry.rs index 773131314c..134f040526 100644 --- a/src/common/function/src/function_registry.rs +++ b/src/common/function/src/function_registry.rs @@ -16,11 +16,14 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; +use datafusion_expr::AggregateUDF; use once_cell::sync::Lazy; use crate::admin::AdminFunction; -use crate::function::{AsyncFunctionRef, FunctionRef}; -use crate::scalars::aggregate::{AggregateFunctionMetaRef, AggregateFunctions}; +use crate::aggrs::approximate::ApproximateFunction; +use crate::aggrs::vector::VectorFunction as VectorAggrFunction; +use crate::function::{AsyncFunctionRef, Function, FunctionRef}; +use crate::function_factory::ScalarFunctionFactory; use crate::scalars::date::DateFunction; use crate::scalars::expression::ExpressionFunction; use crate::scalars::hll_count::HllCalcFunction; @@ -31,18 +34,19 @@ use crate::scalars::matches_term::MatchesTermFunction; use crate::scalars::math::MathFunction; use crate::scalars::timestamp::TimestampFunction; use crate::scalars::uddsketch_calc::UddSketchCalcFunction; -use crate::scalars::vector::VectorFunction; +use crate::scalars::vector::VectorFunction as VectorScalarFunction; use crate::system::SystemFunction; #[derive(Default)] pub struct FunctionRegistry { - functions: RwLock>, + functions: RwLock>, async_functions: RwLock>, - aggregate_functions: RwLock>, + aggregate_functions: RwLock>, } impl FunctionRegistry { - pub fn register(&self, func: FunctionRef) { + pub fn register(&self, func: impl Into) { + let func = func.into(); let _ = self .functions .write() @@ -50,6 +54,10 @@ impl FunctionRegistry { .insert(func.name().to_string(), func); } + pub fn register_scalar(&self, func: impl Function + 'static) { + self.register(Arc::new(func) as FunctionRef); + } + pub fn register_async(&self, func: AsyncFunctionRef) { let _ = self .async_functions @@ -58,6 +66,14 @@ impl FunctionRegistry { .insert(func.name().to_string(), func); } + pub fn register_aggr(&self, func: AggregateUDF) { + let _ = self + .aggregate_functions + .write() + .unwrap() + .insert(func.name().to_string(), func); + } + pub fn get_async_function(&self, name: &str) -> Option { self.async_functions.read().unwrap().get(name).cloned() } @@ -71,27 +87,16 @@ impl FunctionRegistry { .collect() } - pub fn register_aggregate_function(&self, func: AggregateFunctionMetaRef) { - let _ = self - .aggregate_functions - .write() - .unwrap() - .insert(func.name(), func); - } - - pub fn get_aggr_function(&self, name: &str) -> Option { - self.aggregate_functions.read().unwrap().get(name).cloned() - } - - pub fn get_function(&self, name: &str) -> Option { + #[cfg(test)] + pub fn get_function(&self, name: &str) -> Option { self.functions.read().unwrap().get(name).cloned() } - pub fn functions(&self) -> Vec { + pub fn scalar_functions(&self) -> Vec { self.functions.read().unwrap().values().cloned().collect() } - pub fn aggregate_functions(&self) -> Vec { + pub fn aggregate_functions(&self) -> Vec { self.aggregate_functions .read() .unwrap() @@ -112,9 +117,6 @@ pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { UddSketchCalcFunction::register(&function_registry); HllCalcFunction::register(&function_registry); - // Aggregate functions - AggregateFunctions::register(&function_registry); - // Full text search function MatchesFunction::register(&function_registry); MatchesTermFunction::register(&function_registry); @@ -127,15 +129,21 @@ pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { JsonFunction::register(&function_registry); // Vector related functions - VectorFunction::register(&function_registry); + VectorScalarFunction::register(&function_registry); + VectorAggrFunction::register(&function_registry); // Geo functions #[cfg(feature = "geo")] crate::scalars::geo::GeoFunctions::register(&function_registry); + #[cfg(feature = "geo")] + crate::aggrs::geo::GeoFunction::register(&function_registry); // Ip functions IpFunctions::register(&function_registry); + // Approximate functions + ApproximateFunction::register(&function_registry); + Arc::new(function_registry) }); @@ -147,12 +155,11 @@ mod tests { #[test] fn test_function_registry() { let registry = FunctionRegistry::default(); - let func = Arc::new(TestAndFunction); assert!(registry.get_function("test_and").is_none()); - assert!(registry.functions().is_empty()); - registry.register(func); + assert!(registry.scalar_functions().is_empty()); + registry.register_scalar(TestAndFunction); let _ = registry.get_function("test_and").unwrap(); - assert_eq!(1, registry.functions().len()); + assert_eq!(1, registry.scalar_functions().len()); } } diff --git a/src/common/function/src/lib.rs b/src/common/function/src/lib.rs index ea5e20ee3c..95b8b6a3b1 100644 --- a/src/common/function/src/lib.rs +++ b/src/common/function/src/lib.rs @@ -18,13 +18,14 @@ mod admin; mod flush_flow; mod macros; -pub mod scalars; mod system; -pub mod aggr; +pub mod aggrs; pub mod function; +pub mod function_factory; pub mod function_registry; pub mod handlers; pub mod helper; +pub mod scalars; pub mod state; pub mod utils; diff --git a/src/common/function/src/scalars.rs b/src/common/function/src/scalars.rs index ac5389e9fd..6f93f2741d 100644 --- a/src/common/function/src/scalars.rs +++ b/src/common/function/src/scalars.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub mod aggregate; pub(crate) mod date; pub mod expression; #[cfg(feature = "geo")] diff --git a/src/common/function/src/scalars/aggregate.rs b/src/common/function/src/scalars/aggregate.rs deleted file mode 100644 index 65c82ba99c..0000000000 --- a/src/common/function/src/scalars/aggregate.rs +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! # Deprecate Warning: -//! -//! This module is deprecated and will be removed in the future. -//! All UDAF implementation here are not maintained and should -//! not be used before they are refactored into the `src/aggr` -//! version. - -use std::sync::Arc; - -use common_query::logical_plan::AggregateFunctionCreatorRef; - -use crate::function_registry::FunctionRegistry; -use crate::scalars::vector::product::VectorProductCreator; -use crate::scalars::vector::sum::VectorSumCreator; - -/// A function creates `AggregateFunctionCreator`. -/// "Aggregator" *is* AggregatorFunction. Since the later one is long, we named an short alias for it. -/// The two names might be used interchangeably. -type AggregatorCreatorFunction = Arc AggregateFunctionCreatorRef + Send + Sync>; - -/// `AggregateFunctionMeta` dynamically creates AggregateFunctionCreator. -#[derive(Clone)] -pub struct AggregateFunctionMeta { - name: String, - args_count: u8, - creator: AggregatorCreatorFunction, -} - -pub type AggregateFunctionMetaRef = Arc; - -impl AggregateFunctionMeta { - pub fn new(name: &str, args_count: u8, creator: AggregatorCreatorFunction) -> Self { - Self { - name: name.to_string(), - args_count, - creator, - } - } - - pub fn name(&self) -> String { - self.name.to_string() - } - - pub fn args_count(&self) -> u8 { - self.args_count - } - - pub fn create(&self) -> AggregateFunctionCreatorRef { - (self.creator)() - } -} - -pub(crate) struct AggregateFunctions; - -impl AggregateFunctions { - pub fn register(registry: &FunctionRegistry) { - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "vec_sum", - 1, - Arc::new(|| Arc::new(VectorSumCreator::default())), - ))); - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "vec_product", - 1, - Arc::new(|| Arc::new(VectorProductCreator::default())), - ))); - - #[cfg(feature = "geo")] - registry.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "json_encode_path", - 3, - Arc::new(|| Arc::new(super::geo::encoding::JsonPathEncodeFunctionCreator::default())), - ))); - } -} diff --git a/src/common/function/src/scalars/date.rs b/src/common/function/src/scalars/date.rs index 4b8e714ec5..5789d56496 100644 --- a/src/common/function/src/scalars/date.rs +++ b/src/common/function/src/scalars/date.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; mod date_add; mod date_format; mod date_sub; @@ -27,8 +26,8 @@ pub(crate) struct DateFunction; impl DateFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(DateAddFunction)); - registry.register(Arc::new(DateSubFunction)); - registry.register(Arc::new(DateFormatFunction)); + registry.register_scalar(DateAddFunction); + registry.register_scalar(DateSubFunction); + registry.register_scalar(DateFormatFunction); } } diff --git a/src/common/function/src/scalars/expression.rs b/src/common/function/src/scalars/expression.rs index 573202f6e2..90732ba2ce 100644 --- a/src/common/function/src/scalars/expression.rs +++ b/src/common/function/src/scalars/expression.rs @@ -17,8 +17,6 @@ mod ctx; mod is_null; mod unary; -use std::sync::Arc; - pub use binary::scalar_binary_op; pub use ctx::EvalContext; pub use unary::scalar_unary_op; @@ -30,6 +28,6 @@ pub(crate) struct ExpressionFunction; impl ExpressionFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(IsNullFunction)); + registry.register_scalar(IsNullFunction); } } diff --git a/src/common/function/src/scalars/geo.rs b/src/common/function/src/scalars/geo.rs index 37a7b3eb55..7dbcc890c0 100644 --- a/src/common/function/src/scalars/geo.rs +++ b/src/common/function/src/scalars/geo.rs @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; -pub(crate) mod encoding; mod geohash; mod h3; -mod helpers; +pub(crate) mod helpers; mod measure; mod relation; mod s2; @@ -29,57 +27,57 @@ pub(crate) struct GeoFunctions; impl GeoFunctions { pub fn register(registry: &FunctionRegistry) { // geohash - registry.register(Arc::new(geohash::GeohashFunction)); - registry.register(Arc::new(geohash::GeohashNeighboursFunction)); + registry.register_scalar(geohash::GeohashFunction); + registry.register_scalar(geohash::GeohashNeighboursFunction); // h3 index - registry.register(Arc::new(h3::H3LatLngToCell)); - registry.register(Arc::new(h3::H3LatLngToCellString)); + registry.register_scalar(h3::H3LatLngToCell); + registry.register_scalar(h3::H3LatLngToCellString); // h3 index inspection - registry.register(Arc::new(h3::H3CellBase)); - registry.register(Arc::new(h3::H3CellIsPentagon)); - registry.register(Arc::new(h3::H3StringToCell)); - registry.register(Arc::new(h3::H3CellToString)); - registry.register(Arc::new(h3::H3CellCenterLatLng)); - registry.register(Arc::new(h3::H3CellResolution)); + registry.register_scalar(h3::H3CellBase); + registry.register_scalar(h3::H3CellIsPentagon); + registry.register_scalar(h3::H3StringToCell); + registry.register_scalar(h3::H3CellToString); + registry.register_scalar(h3::H3CellCenterLatLng); + registry.register_scalar(h3::H3CellResolution); // h3 hierarchical grid - registry.register(Arc::new(h3::H3CellCenterChild)); - registry.register(Arc::new(h3::H3CellParent)); - registry.register(Arc::new(h3::H3CellToChildren)); - registry.register(Arc::new(h3::H3CellToChildrenSize)); - registry.register(Arc::new(h3::H3CellToChildPos)); - registry.register(Arc::new(h3::H3ChildPosToCell)); - registry.register(Arc::new(h3::H3CellContains)); + registry.register_scalar(h3::H3CellCenterChild); + registry.register_scalar(h3::H3CellParent); + registry.register_scalar(h3::H3CellToChildren); + registry.register_scalar(h3::H3CellToChildrenSize); + registry.register_scalar(h3::H3CellToChildPos); + registry.register_scalar(h3::H3ChildPosToCell); + registry.register_scalar(h3::H3CellContains); // h3 grid traversal - registry.register(Arc::new(h3::H3GridDisk)); - registry.register(Arc::new(h3::H3GridDiskDistances)); - registry.register(Arc::new(h3::H3GridDistance)); - registry.register(Arc::new(h3::H3GridPathCells)); + registry.register_scalar(h3::H3GridDisk); + registry.register_scalar(h3::H3GridDiskDistances); + registry.register_scalar(h3::H3GridDistance); + registry.register_scalar(h3::H3GridPathCells); // h3 measurement - registry.register(Arc::new(h3::H3CellDistanceSphereKm)); - registry.register(Arc::new(h3::H3CellDistanceEuclideanDegree)); + registry.register_scalar(h3::H3CellDistanceSphereKm); + registry.register_scalar(h3::H3CellDistanceEuclideanDegree); // s2 - registry.register(Arc::new(s2::S2LatLngToCell)); - registry.register(Arc::new(s2::S2CellLevel)); - registry.register(Arc::new(s2::S2CellToToken)); - registry.register(Arc::new(s2::S2CellParent)); + registry.register_scalar(s2::S2LatLngToCell); + registry.register_scalar(s2::S2CellLevel); + registry.register_scalar(s2::S2CellToToken); + registry.register_scalar(s2::S2CellParent); // spatial data type - registry.register(Arc::new(wkt::LatLngToPointWkt)); + registry.register_scalar(wkt::LatLngToPointWkt); // spatial relation - registry.register(Arc::new(relation::STContains)); - registry.register(Arc::new(relation::STWithin)); - registry.register(Arc::new(relation::STIntersects)); + registry.register_scalar(relation::STContains); + registry.register_scalar(relation::STWithin); + registry.register_scalar(relation::STIntersects); // spatial measure - registry.register(Arc::new(measure::STDistance)); - registry.register(Arc::new(measure::STDistanceSphere)); - registry.register(Arc::new(measure::STArea)); + registry.register_scalar(measure::STDistance); + registry.register_scalar(measure::STDistanceSphere); + registry.register_scalar(measure::STArea); } } diff --git a/src/common/function/src/scalars/geo/helpers.rs b/src/common/function/src/scalars/geo/helpers.rs index 22d47f54e4..aba3c80543 100644 --- a/src/common/function/src/scalars/geo/helpers.rs +++ b/src/common/function/src/scalars/geo/helpers.rs @@ -37,7 +37,7 @@ macro_rules! ensure_columns_len { }; } -pub(super) use ensure_columns_len; +pub(crate) use ensure_columns_len; macro_rules! ensure_columns_n { ($columns:ident, $n:literal) => { @@ -58,7 +58,7 @@ macro_rules! ensure_columns_n { }; } -pub(super) use ensure_columns_n; +pub(crate) use ensure_columns_n; macro_rules! ensure_and_coerce { ($compare:expr, $coerce:expr) => {{ @@ -72,4 +72,4 @@ macro_rules! ensure_and_coerce { }}; } -pub(super) use ensure_and_coerce; +pub(crate) use ensure_and_coerce; diff --git a/src/common/function/src/scalars/hll_count.rs b/src/common/function/src/scalars/hll_count.rs index 6cde0c7064..c40d74a154 100644 --- a/src/common/function/src/scalars/hll_count.rs +++ b/src/common/function/src/scalars/hll_count.rs @@ -16,7 +16,6 @@ use std::fmt; use std::fmt::Display; -use std::sync::Arc; use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result}; use common_query::prelude::{Signature, Volatility}; @@ -27,7 +26,7 @@ use datatypes::vectors::{BinaryVector, MutableVector, UInt64VectorBuilder, Vecto use hyperloglogplus::HyperLogLog; use snafu::OptionExt; -use crate::aggr::HllStateType; +use crate::aggrs::approximate::hll::HllStateType; use crate::function::{Function, FunctionContext}; use crate::function_registry::FunctionRegistry; @@ -44,7 +43,7 @@ pub struct HllCalcFunction; impl HllCalcFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(HllCalcFunction)); + registry.register_scalar(HllCalcFunction); } } @@ -117,6 +116,8 @@ impl Function for HllCalcFunction { #[cfg(test)] mod tests { + use std::sync::Arc; + use datatypes::vectors::BinaryVector; use super::*; diff --git a/src/common/function/src/scalars/ip.rs b/src/common/function/src/scalars/ip.rs index 8e860b3346..f10ac2e83b 100644 --- a/src/common/function/src/scalars/ip.rs +++ b/src/common/function/src/scalars/ip.rs @@ -17,8 +17,6 @@ mod ipv4; mod ipv6; mod range; -use std::sync::Arc; - use cidr::{Ipv4ToCidr, Ipv6ToCidr}; use ipv4::{Ipv4NumToString, Ipv4StringToNum}; use ipv6::{Ipv6NumToString, Ipv6StringToNum}; @@ -31,15 +29,15 @@ pub(crate) struct IpFunctions; impl IpFunctions { pub fn register(registry: &FunctionRegistry) { // Register IPv4 functions - registry.register(Arc::new(Ipv4NumToString)); - registry.register(Arc::new(Ipv4StringToNum)); - registry.register(Arc::new(Ipv4ToCidr)); - registry.register(Arc::new(Ipv4InRange)); + registry.register_scalar(Ipv4NumToString); + registry.register_scalar(Ipv4StringToNum); + registry.register_scalar(Ipv4ToCidr); + registry.register_scalar(Ipv4InRange); // Register IPv6 functions - registry.register(Arc::new(Ipv6NumToString)); - registry.register(Arc::new(Ipv6StringToNum)); - registry.register(Arc::new(Ipv6ToCidr)); - registry.register(Arc::new(Ipv6InRange)); + registry.register_scalar(Ipv6NumToString); + registry.register_scalar(Ipv6StringToNum); + registry.register_scalar(Ipv6ToCidr); + registry.register_scalar(Ipv6InRange); } } diff --git a/src/common/function/src/scalars/json.rs b/src/common/function/src/scalars/json.rs index 9cde42dcdb..e3eb0b0c9a 100644 --- a/src/common/function/src/scalars/json.rs +++ b/src/common/function/src/scalars/json.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; pub mod json_get; mod json_is; mod json_path_exists; @@ -33,23 +32,23 @@ pub(crate) struct JsonFunction; impl JsonFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(JsonToStringFunction)); - registry.register(Arc::new(ParseJsonFunction)); + registry.register_scalar(JsonToStringFunction); + registry.register_scalar(ParseJsonFunction); - registry.register(Arc::new(JsonGetInt)); - registry.register(Arc::new(JsonGetFloat)); - registry.register(Arc::new(JsonGetString)); - registry.register(Arc::new(JsonGetBool)); + registry.register_scalar(JsonGetInt); + registry.register_scalar(JsonGetFloat); + registry.register_scalar(JsonGetString); + registry.register_scalar(JsonGetBool); - registry.register(Arc::new(JsonIsNull)); - registry.register(Arc::new(JsonIsInt)); - registry.register(Arc::new(JsonIsFloat)); - registry.register(Arc::new(JsonIsString)); - registry.register(Arc::new(JsonIsBool)); - registry.register(Arc::new(JsonIsArray)); - registry.register(Arc::new(JsonIsObject)); + registry.register_scalar(JsonIsNull); + registry.register_scalar(JsonIsInt); + registry.register_scalar(JsonIsFloat); + registry.register_scalar(JsonIsString); + registry.register_scalar(JsonIsBool); + registry.register_scalar(JsonIsArray); + registry.register_scalar(JsonIsObject); - registry.register(Arc::new(json_path_exists::JsonPathExistsFunction)); - registry.register(Arc::new(json_path_match::JsonPathMatchFunction)); + registry.register_scalar(json_path_exists::JsonPathExistsFunction); + registry.register_scalar(json_path_match::JsonPathMatchFunction); } } diff --git a/src/common/function/src/scalars/matches.rs b/src/common/function/src/scalars/matches.rs index edeffbb2f9..332e9890cb 100644 --- a/src/common/function/src/scalars/matches.rs +++ b/src/common/function/src/scalars/matches.rs @@ -38,11 +38,11 @@ use crate::function_registry::FunctionRegistry; /// /// Usage: matches(``, ``) -> boolean #[derive(Clone, Debug, Default)] -pub(crate) struct MatchesFunction; +pub struct MatchesFunction; impl MatchesFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(MatchesFunction)); + registry.register_scalar(MatchesFunction); } } diff --git a/src/common/function/src/scalars/matches_term.rs b/src/common/function/src/scalars/matches_term.rs index 54cf556e85..018e269bbe 100644 --- a/src/common/function/src/scalars/matches_term.rs +++ b/src/common/function/src/scalars/matches_term.rs @@ -77,7 +77,7 @@ pub struct MatchesTermFunction; impl MatchesTermFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(MatchesTermFunction)); + registry.register_scalar(MatchesTermFunction); } } diff --git a/src/common/function/src/scalars/math.rs b/src/common/function/src/scalars/math.rs index fd37a9fd6e..bb55f72e1c 100644 --- a/src/common/function/src/scalars/math.rs +++ b/src/common/function/src/scalars/math.rs @@ -18,7 +18,6 @@ mod pow; mod rate; use std::fmt; -use std::sync::Arc; pub use clamp::{ClampFunction, ClampMaxFunction, ClampMinFunction}; use common_query::error::{GeneralDataFusionSnafu, Result}; @@ -39,13 +38,13 @@ pub(crate) struct MathFunction; impl MathFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(ModuloFunction)); - registry.register(Arc::new(PowFunction)); - registry.register(Arc::new(RateFunction)); - registry.register(Arc::new(RangeFunction)); - registry.register(Arc::new(ClampFunction)); - registry.register(Arc::new(ClampMinFunction)); - registry.register(Arc::new(ClampMaxFunction)); + registry.register_scalar(ModuloFunction); + registry.register_scalar(PowFunction); + registry.register_scalar(RateFunction); + registry.register_scalar(RangeFunction); + registry.register_scalar(ClampFunction); + registry.register_scalar(ClampMinFunction); + registry.register_scalar(ClampMaxFunction); } } diff --git a/src/common/function/src/scalars/timestamp.rs b/src/common/function/src/scalars/timestamp.rs index 35676ac793..faaf8d0524 100644 --- a/src/common/function/src/scalars/timestamp.rs +++ b/src/common/function/src/scalars/timestamp.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; mod to_unixtime; use to_unixtime::ToUnixtimeFunction; @@ -23,6 +22,6 @@ pub(crate) struct TimestampFunction; impl TimestampFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(ToUnixtimeFunction)); + registry.register_scalar(ToUnixtimeFunction); } } diff --git a/src/common/function/src/scalars/uddsketch_calc.rs b/src/common/function/src/scalars/uddsketch_calc.rs index f429766eb7..917ab63244 100644 --- a/src/common/function/src/scalars/uddsketch_calc.rs +++ b/src/common/function/src/scalars/uddsketch_calc.rs @@ -16,7 +16,6 @@ use std::fmt; use std::fmt::Display; -use std::sync::Arc; use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result}; use common_query::prelude::{Signature, Volatility}; @@ -44,7 +43,7 @@ pub struct UddSketchCalcFunction; impl UddSketchCalcFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(UddSketchCalcFunction)); + registry.register_scalar(UddSketchCalcFunction); } } diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index d8dc195e5b..46abf1e163 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -17,10 +17,8 @@ mod distance; mod elem_product; mod elem_sum; pub mod impl_conv; -pub(crate) mod product; mod scalar_add; mod scalar_mul; -pub(crate) mod sum; mod vector_add; mod vector_dim; mod vector_div; @@ -30,37 +28,34 @@ mod vector_norm; mod vector_sub; mod vector_subvector; -use std::sync::Arc; - use crate::function_registry::FunctionRegistry; - pub(crate) struct VectorFunction; impl VectorFunction { pub fn register(registry: &FunctionRegistry) { // conversion - registry.register(Arc::new(convert::ParseVectorFunction)); - registry.register(Arc::new(convert::VectorToStringFunction)); + registry.register_scalar(convert::ParseVectorFunction); + registry.register_scalar(convert::VectorToStringFunction); // distance - registry.register(Arc::new(distance::CosDistanceFunction)); - registry.register(Arc::new(distance::DotProductFunction)); - registry.register(Arc::new(distance::L2SqDistanceFunction)); + registry.register_scalar(distance::CosDistanceFunction); + registry.register_scalar(distance::DotProductFunction); + registry.register_scalar(distance::L2SqDistanceFunction); // scalar calculation - registry.register(Arc::new(scalar_add::ScalarAddFunction)); - registry.register(Arc::new(scalar_mul::ScalarMulFunction)); + registry.register_scalar(scalar_add::ScalarAddFunction); + registry.register_scalar(scalar_mul::ScalarMulFunction); // vector calculation - registry.register(Arc::new(vector_add::VectorAddFunction)); - registry.register(Arc::new(vector_sub::VectorSubFunction)); - registry.register(Arc::new(vector_mul::VectorMulFunction)); - registry.register(Arc::new(vector_div::VectorDivFunction)); - registry.register(Arc::new(vector_norm::VectorNormFunction)); - registry.register(Arc::new(vector_dim::VectorDimFunction)); - registry.register(Arc::new(vector_kth_elem::VectorKthElemFunction)); - registry.register(Arc::new(vector_subvector::VectorSubvectorFunction)); - registry.register(Arc::new(elem_sum::ElemSumFunction)); - registry.register(Arc::new(elem_product::ElemProductFunction)); + registry.register_scalar(vector_add::VectorAddFunction); + registry.register_scalar(vector_sub::VectorSubFunction); + registry.register_scalar(vector_mul::VectorMulFunction); + registry.register_scalar(vector_div::VectorDivFunction); + registry.register_scalar(vector_norm::VectorNormFunction); + registry.register_scalar(vector_dim::VectorDimFunction); + registry.register_scalar(vector_kth_elem::VectorKthElemFunction); + registry.register_scalar(vector_subvector::VectorSubvectorFunction); + registry.register_scalar(elem_sum::ElemSumFunction); + registry.register_scalar(elem_product::ElemProductFunction); } } diff --git a/src/common/function/src/system.rs b/src/common/function/src/system.rs index dad1e4f7bf..98acafb72e 100644 --- a/src/common/function/src/system.rs +++ b/src/common/function/src/system.rs @@ -36,13 +36,13 @@ pub(crate) struct SystemFunction; impl SystemFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(BuildFunction)); - registry.register(Arc::new(VersionFunction)); - registry.register(Arc::new(CurrentSchemaFunction)); - registry.register(Arc::new(DatabaseFunction)); - registry.register(Arc::new(SessionUserFunction)); - registry.register(Arc::new(ReadPreferenceFunction)); - registry.register(Arc::new(TimezoneFunction)); + registry.register_scalar(BuildFunction); + registry.register_scalar(VersionFunction); + registry.register_scalar(CurrentSchemaFunction); + registry.register_scalar(DatabaseFunction); + registry.register_scalar(SessionUserFunction); + registry.register_scalar(ReadPreferenceFunction); + registry.register_scalar(TimezoneFunction); registry.register_async(Arc::new(ProcedureStateFunction)); PGCatalogFunction::register(registry); } diff --git a/src/common/function/src/system/pg_catalog.rs b/src/common/function/src/system/pg_catalog.rs index 26b7dc4f24..b064e11268 100644 --- a/src/common/function/src/system/pg_catalog.rs +++ b/src/common/function/src/system/pg_catalog.rs @@ -16,8 +16,6 @@ mod pg_get_userbyid; mod table_is_visible; mod version; -use std::sync::Arc; - use pg_get_userbyid::PGGetUserByIdFunction; use table_is_visible::PGTableIsVisibleFunction; use version::PGVersionFunction; @@ -35,8 +33,8 @@ pub(super) struct PGCatalogFunction; impl PGCatalogFunction { pub fn register(registry: &FunctionRegistry) { - registry.register(Arc::new(PGTableIsVisibleFunction)); - registry.register(Arc::new(PGGetUserByIdFunction)); - registry.register(Arc::new(PGVersionFunction)); + registry.register_scalar(PGTableIsVisibleFunction); + registry.register_scalar(PGGetUserByIdFunction); + registry.register_scalar(PGVersionFunction); } } diff --git a/src/datanode/src/tests.rs b/src/datanode/src/tests.rs index f182e1c423..ee6b611b46 100644 --- a/src/datanode/src/tests.rs +++ b/src/datanode/src/tests.rs @@ -19,12 +19,11 @@ use std::time::Duration; use api::region::RegionResponse; use async_trait::async_trait; use common_error::ext::BoxedError; -use common_function::function::FunctionRef; -use common_function::scalars::aggregate::AggregateFunctionMetaRef; +use common_function::function_factory::ScalarFunctionFactory; use common_query::Output; use common_runtime::runtime::{BuilderBuild, RuntimeTrait}; use common_runtime::Runtime; -use datafusion_expr::LogicalPlan; +use datafusion_expr::{AggregateUDF, LogicalPlan}; use query::dataframe::DataFrame; use query::planner::LogicalPlanner; use query::query_engine::{DescribeResult, QueryEngineState}; @@ -76,9 +75,9 @@ impl QueryEngine for MockQueryEngine { unimplemented!() } - fn register_aggregate_function(&self, _func: AggregateFunctionMetaRef) {} + fn register_aggregate_function(&self, _func: AggregateUDF) {} - fn register_function(&self, _func: FunctionRef) {} + fn register_scalar_function(&self, _func: ScalarFunctionFactory) {} fn read_table(&self, _table: TableRef) -> query::error::Result { unimplemented!() diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index 04c7f40e68..a3ecfcd5fe 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -17,7 +17,7 @@ use std::collections::BTreeMap; use std::sync::Arc; use common_error::ext::BoxedError; -use common_function::function::FunctionContext; +use common_function::function::{FunctionContext, FunctionRef}; use datafusion_substrait::extensions::Extensions; use datatypes::data_type::ConcreteDataType as CDT; use query::QueryEngine; @@ -108,9 +108,13 @@ impl FunctionExtensions { /// register flow-specific functions to the query engine pub fn register_function_to_query_engine(engine: &Arc) { - engine.register_function(Arc::new(TumbleFunction::new("tumble"))); - engine.register_function(Arc::new(TumbleFunction::new(TUMBLE_START))); - engine.register_function(Arc::new(TumbleFunction::new(TUMBLE_END))); + let tumble_fn = Arc::new(TumbleFunction::new("tumble")) as FunctionRef; + let tumble_start_fn = Arc::new(TumbleFunction::new(TUMBLE_START)) as FunctionRef; + let tumble_end_fn = Arc::new(TumbleFunction::new(TUMBLE_END)) as FunctionRef; + + engine.register_scalar_function(tumble_fn.into()); + engine.register_scalar_function(tumble_start_fn.into()); + engine.register_scalar_function(tumble_end_fn.into()); } #[derive(Debug)] diff --git a/src/mito2/src/sst/index/fulltext_index/applier/builder.rs b/src/mito2/src/sst/index/fulltext_index/applier/builder.rs index d3054306e6..81eb542625 100644 --- a/src/mito2/src/sst/index/fulltext_index/applier/builder.rs +++ b/src/mito2/src/sst/index/fulltext_index/applier/builder.rs @@ -282,14 +282,15 @@ mod tests { use std::sync::Arc; use api::v1::SemanticType; - use common_function::function_registry::FUNCTION_REGISTRY; - use common_function::scalars::udf::create_udf; + use common_function::function::FunctionRef; + use common_function::function_factory::ScalarFunctionFactory; + use common_function::scalars::matches::MatchesFunction; + use common_function::scalars::matches_term::MatchesTermFunction; use datafusion::functions::string::lower; use datafusion_common::Column; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::ScalarUDF; use datatypes::schema::ColumnSchema; - use session::context::QueryContext; use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder}; use store_api::storage::RegionId; @@ -317,19 +318,17 @@ mod tests { } fn matches_func() -> Arc { - Arc::new(create_udf( - FUNCTION_REGISTRY.get_function("matches").unwrap(), - QueryContext::arc(), - Default::default(), - )) + Arc::new( + ScalarFunctionFactory::from(Arc::new(MatchesFunction) as FunctionRef) + .provide(Default::default()), + ) } fn matches_term_func() -> Arc { - Arc::new(create_udf( - FUNCTION_REGISTRY.get_function("matches_term").unwrap(), - QueryContext::arc(), - Default::default(), - )) + Arc::new( + ScalarFunctionFactory::from(Arc::new(MatchesTermFunction) as FunctionRef) + .provide(Default::default()), + ) } #[test] diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index db4207fd8a..685f27c355 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -25,8 +25,7 @@ use async_trait::async_trait; use common_base::Plugins; use common_catalog::consts::is_readonly_schema; use common_error::ext::BoxedError; -use common_function::function::FunctionRef; -use common_function::scalars::aggregate::AggregateFunctionMetaRef; +use common_function::function_factory::ScalarFunctionFactory; use common_query::{Output, OutputData, OutputMeta}; use common_recordbatch::adapter::RecordBatchStreamAdapter; use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; @@ -35,7 +34,9 @@ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::ExecutionPlan; use datafusion_common::ResolvedTableReference; -use datafusion_expr::{DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WriteOp}; +use datafusion_expr::{ + AggregateUDF, DmlStatement, LogicalPlan as DfLogicalPlan, LogicalPlan, WriteOp, +}; use datatypes::prelude::VectorRef; use datatypes::schema::Schema; use futures_util::StreamExt; @@ -454,14 +455,14 @@ impl QueryEngine for DatafusionQueryEngine { /// `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"` /// /// So it's better to make UDAF name lowercase when creating one. - fn register_aggregate_function(&self, func: AggregateFunctionMetaRef) { - self.state.register_aggregate_function(func); + fn register_aggregate_function(&self, func: AggregateUDF) { + self.state.register_aggr_function(func); } - /// Register an UDF function. + /// Register an scalar function. /// Will override if the function with same name is already registered. - fn register_function(&self, func: FunctionRef) { - self.state.register_function(func); + fn register_scalar_function(&self, func: ScalarFunctionFactory) { + self.state.register_scalar_function(func); } fn read_table(&self, table: TableRef) -> Result { diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 6d0d99e296..5607e724ff 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -18,12 +18,7 @@ use std::sync::Arc; use arrow_schema::DataType; use catalog::table_source::DfTableSourceProvider; -use common_function::aggr::{ - GeoPathAccumulator, HllState, UddSketchState, GEO_PATH_NAME, HLL_MERGE_NAME, HLL_NAME, - UDDSKETCH_MERGE_NAME, UDDSKETCH_STATE_NAME, -}; -use common_function::scalars::udf::create_udf; -use common_query::logical_plan::create_aggregate_function; +use common_function::function::FunctionContext; use datafusion::common::TableReference; use datafusion::datasource::cte_worktable::CteWorkTable; use datafusion::datasource::file_format::{format_as_file_type, FileFormatFactory}; @@ -151,38 +146,21 @@ impl ContextProvider for DfContextProviderAdapter { } fn get_function_meta(&self, name: &str) -> Option> { - self.engine_state.udf_function(name).map_or_else( + self.engine_state.scalar_function(name).map_or_else( || self.session_state.scalar_functions().get(name).cloned(), |func| { - Some(Arc::new(create_udf( - func, - self.query_ctx.clone(), - self.engine_state.function_state(), - ))) + Some(Arc::new(func.provide(FunctionContext { + query_ctx: self.query_ctx.clone(), + state: self.engine_state.function_state(), + }))) }, ) } fn get_aggregate_meta(&self, name: &str) -> Option> { - if name == UDDSKETCH_STATE_NAME { - return Some(Arc::new(UddSketchState::state_udf_impl())); - } else if name == UDDSKETCH_MERGE_NAME { - return Some(Arc::new(UddSketchState::merge_udf_impl())); - } else if name == HLL_NAME { - return Some(Arc::new(HllState::state_udf_impl())); - } else if name == HLL_MERGE_NAME { - return Some(Arc::new(HllState::merge_udf_impl())); - } else if name == GEO_PATH_NAME { - return Some(Arc::new(GeoPathAccumulator::udf_impl())); - } - - self.engine_state.aggregate_function(name).map_or_else( + self.engine_state.aggr_function(name).map_or_else( || self.session_state.aggregate_functions().get(name).cloned(), - |func| { - Some(Arc::new( - create_aggregate_function(func.name(), func.args_count(), func.create()).into(), - )) - }, + |func| Some(Arc::new(func)), ) } @@ -213,13 +191,13 @@ impl ContextProvider for DfContextProviderAdapter { } fn udf_names(&self) -> Vec { - let mut names = self.engine_state.udf_names(); + let mut names = self.engine_state.scalar_names(); names.extend(self.session_state.scalar_functions().keys().cloned()); names } fn udaf_names(&self) -> Vec { - let mut names = self.engine_state.udaf_names(); + let mut names = self.engine_state.aggr_names(); names.extend(self.session_state.aggregate_functions().keys().cloned()); names } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 8b0c091054..e413854b75 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -22,14 +22,13 @@ use std::sync::Arc; use async_trait::async_trait; use catalog::CatalogManagerRef; use common_base::Plugins; -use common_function::function::FunctionRef; +use common_function::function_factory::ScalarFunctionFactory; use common_function::function_registry::FUNCTION_REGISTRY; use common_function::handlers::{ FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef, }; -use common_function::scalars::aggregate::AggregateFunctionMetaRef; use common_query::Output; -use datafusion_expr::LogicalPlan; +use datafusion_expr::{AggregateUDF, LogicalPlan}; use datatypes::schema::Schema; pub use default_serializer::{DefaultPlanDecoder, DefaultSerializer}; use session::context::QueryContextRef; @@ -79,11 +78,11 @@ pub trait QueryEngine: Send + Sync { /// /// # Panics /// Will panic if the function with same name is already registered. - fn register_aggregate_function(&self, func: AggregateFunctionMetaRef); + fn register_aggregate_function(&self, func: AggregateUDF); - /// Register a SQL function. + /// Register a scalar function. /// Will override if the function with same name is already registered. - fn register_function(&self, func: FunctionRef); + fn register_scalar_function(&self, func: ScalarFunctionFactory); /// Create a DataFrame from a table. fn read_table(&self, table: TableRef) -> Result; @@ -154,8 +153,8 @@ impl QueryEngineFactory { /// Register all functions implemented by GreptimeDB fn register_functions(query_engine: &Arc) { - for func in FUNCTION_REGISTRY.functions() { - query_engine.register_function(func); + for func in FUNCTION_REGISTRY.scalar_functions() { + query_engine.register_scalar_function(func); } for accumulator in FUNCTION_REGISTRY.aggregate_functions() { diff --git a/src/query/src/query_engine/default_serializer.rs b/src/query/src/query_engine/default_serializer.rs index 6045415a9e..50f7c79ff3 100644 --- a/src/query/src/query_engine/default_serializer.rs +++ b/src/query/src/query_engine/default_serializer.rs @@ -15,9 +15,8 @@ use std::sync::Arc; use common_error::ext::BoxedError; -use common_function::aggr::{GeoPathAccumulator, HllState, UddSketchState}; +use common_function::function::FunctionContext; use common_function::function_registry::FUNCTION_REGISTRY; -use common_function::scalars::udf::create_udf; use common_query::error::RegisterUdfSnafu; use common_query::logical_plan::SubstraitPlanDecoder; use datafusion::catalog::CatalogProviderList; @@ -124,43 +123,46 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { // if they have the same name as the default UDFs or their alias. // e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()` // before we build the session state, the UDF will be lost. - for func in FUNCTION_REGISTRY.functions() { - let udf = Arc::new(create_udf( - func.clone(), - self.query_ctx.clone(), - Default::default(), - )); + for func in FUNCTION_REGISTRY.scalar_functions() { + let udf = func.provide(FunctionContext { + query_ctx: self.query_ctx.clone(), + state: Default::default(), + }); session_state - .register_udf(udf) + .register_udf(Arc::new(udf)) .context(RegisterUdfSnafu { name: func.name() })?; - let _ = session_state.register_udaf(Arc::new(UddSketchState::state_udf_impl())); - let _ = session_state.register_udaf(Arc::new(UddSketchState::merge_udf_impl())); - let _ = session_state.register_udaf(Arc::new(HllState::state_udf_impl())); - let _ = session_state.register_udaf(Arc::new(HllState::merge_udf_impl())); - let _ = session_state.register_udaf(Arc::new(GeoPathAccumulator::udf_impl())); - let _ = session_state.register_udaf(quantile_udaf()); - - let _ = session_state.register_udf(Arc::new(IDelta::::scalar_udf())); - let _ = session_state.register_udf(Arc::new(IDelta::::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Rate::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Increase::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Delta::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Resets::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Changes::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Round::scalar_udf())); - let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf())); - // TODO(ruihang): add quantile_over_time, predict_linear, holt_winters, round } + + for func in FUNCTION_REGISTRY.aggregate_functions() { + let name = func.name().to_string(); + session_state + .register_udaf(Arc::new(func)) + .context(RegisterUdfSnafu { name })?; + } + + let _ = session_state.register_udaf(quantile_udaf()); + + let _ = session_state.register_udf(Arc::new(IDelta::::scalar_udf())); + let _ = session_state.register_udf(Arc::new(IDelta::::scalar_udf())); + let _ = session_state.register_udf(Arc::new(Rate::scalar_udf())); + let _ = session_state.register_udf(Arc::new(Increase::scalar_udf())); + let _ = session_state.register_udf(Arc::new(Delta::scalar_udf())); + let _ = session_state.register_udf(Arc::new(Resets::scalar_udf())); + let _ = session_state.register_udf(Arc::new(Changes::scalar_udf())); + let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf())); + let _ = session_state.register_udf(Arc::new(Round::scalar_udf())); + let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf())); + let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf())); + // TODO(ruihang): add quantile_over_time, predict_linear, holt_winters, round + let logical_plan = DFLogicalSubstraitConvertor .decode(message, session_state) .await diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 7acd38aa37..09f5e25d9d 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -19,11 +19,10 @@ use std::sync::{Arc, RwLock}; use async_trait::async_trait; use catalog::CatalogManagerRef; use common_base::Plugins; -use common_function::function::FunctionRef; +use common_function::function_factory::ScalarFunctionFactory; use common_function::handlers::{ FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef, }; -use common_function::scalars::aggregate::AggregateFunctionMetaRef; use common_function::state::FunctionState; use common_telemetry::warn; use datafusion::dataframe::DataFrame; @@ -37,7 +36,7 @@ use datafusion::physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}; -use datafusion_expr::LogicalPlan as DfLogicalPlan; +use datafusion_expr::{AggregateUDF, LogicalPlan as DfLogicalPlan}; use datafusion_optimizer::analyzer::count_wildcard_rule::CountWildcardRule; use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; use datafusion_optimizer::optimizer::Optimizer; @@ -70,8 +69,8 @@ pub struct QueryEngineState { df_context: SessionContext, catalog_manager: CatalogManagerRef, function_state: Arc, - udf_functions: Arc>>, - aggregate_functions: Arc>>, + scalar_functions: Arc>>, + aggr_functions: Arc>>, extension_rules: Vec>, plugins: Plugins, } @@ -186,10 +185,10 @@ impl QueryEngineState { procedure_service_handler, flow_service_handler, }), - aggregate_functions: Arc::new(RwLock::new(HashMap::new())), + aggr_functions: Arc::new(RwLock::new(HashMap::new())), extension_rules, plugins, - udf_functions: Arc::new(RwLock::new(HashMap::new())), + scalar_functions: Arc::new(RwLock::new(HashMap::new())), } } @@ -222,38 +221,28 @@ impl QueryEngineState { self.session_state().optimize(&plan) } - /// Register an udf function. - /// Will override if the function with same name is already registered. - pub fn register_function(&self, func: FunctionRef) { - let name = func.name().to_string(); - let x = self - .udf_functions - .write() - .unwrap() - .insert(name.clone(), func); - - if x.is_some() { - warn!("Already registered udf function '{name}'"); - } - } - - /// Retrieve the udf function by name - pub fn udf_function(&self, function_name: &str) -> Option { - self.udf_functions + /// Retrieve the scalar function by name + pub fn scalar_function(&self, function_name: &str) -> Option { + self.scalar_functions .read() .unwrap() .get(function_name) .cloned() } - /// Retrieve udf function names. - pub fn udf_names(&self) -> Vec { - self.udf_functions.read().unwrap().keys().cloned().collect() + /// Retrieve scalar function names. + pub fn scalar_names(&self) -> Vec { + self.scalar_functions + .read() + .unwrap() + .keys() + .cloned() + .collect() } /// Retrieve the aggregate function by name - pub fn aggregate_function(&self, function_name: &str) -> Option { - self.aggregate_functions + pub fn aggr_function(&self, function_name: &str) -> Option { + self.aggr_functions .read() .unwrap() .get(function_name) @@ -261,8 +250,8 @@ impl QueryEngineState { } /// Retrieve aggregate function names. - pub fn udaf_names(&self) -> Vec { - self.aggregate_functions + pub fn aggr_names(&self) -> Vec { + self.aggr_functions .read() .unwrap() .keys() @@ -270,6 +259,21 @@ impl QueryEngineState { .collect() } + /// Register an scalar function. + /// Will override if the function with same name is already registered. + pub fn register_scalar_function(&self, func: ScalarFunctionFactory) { + let name = func.name().to_string(); + let x = self + .scalar_functions + .write() + .unwrap() + .insert(name.clone(), func); + + if x.is_some() { + warn!("Already registered scalar function '{name}'"); + } + } + /// Register an aggregate function. /// /// # Panics @@ -278,10 +282,10 @@ impl QueryEngineState { /// Panicking consideration: currently the aggregated functions are all statically registered, /// user cannot define their own aggregate functions on the fly. So we can panic here. If that /// invariant is broken in the future, we should return an error instead of panicking. - pub fn register_aggregate_function(&self, func: AggregateFunctionMetaRef) { - let name = func.name(); + pub fn register_aggr_function(&self, func: AggregateUDF) { + let name = func.name().to_string(); let x = self - .aggregate_functions + .aggr_functions .write() .unwrap() .insert(name.clone(), func); diff --git a/src/query/src/tests/my_sum_udaf_example.rs b/src/query/src/tests/my_sum_udaf_example.rs index 63653ffbce..9c7bc0fdd1 100644 --- a/src/query/src/tests/my_sum_udaf_example.rs +++ b/src/query/src/tests/my_sum_udaf_example.rs @@ -16,11 +16,12 @@ use std::fmt::Debug; use std::marker::PhantomData; use std::sync::Arc; -use common_function::scalars::aggregate::AggregateFunctionMeta; use common_macro::{as_aggr_func_creator, AggrFuncTypeStore}; use common_query::error::{CreateAccumulatorSnafu, Result as QueryResult}; use common_query::logical_plan::accumulator::AggrFuncTypeStore; -use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; +use common_query::logical_plan::{ + create_aggregate_function, Accumulator, AggregateFunctionCreator, +}; use common_query::prelude::*; use common_recordbatch::{RecordBatch, RecordBatches}; use datatypes::prelude::*; @@ -207,11 +208,14 @@ where let engine = new_query_engine_with_table(testing_table); - engine.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( - "my_sum", - 1, - Arc::new(|| Arc::new(MySumAccumulatorCreator::default())), - ))); + engine.register_aggregate_function( + create_aggregate_function( + "my_sum".to_string(), + 1, + Arc::new(MySumAccumulatorCreator::default()), + ) + .into(), + ); let sql = format!("select MY_SUM({column_name}) as my_sum from {table_name}"); let batches = exec_selection(engine, &sql).await;