common_function/
function_registry.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Function registry: name-keyed storage for scalar, aggregate, table and window functions.
16use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::{AggregateUDF, WindowUDF};
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::anomaly::AnomalyFunction;
30use crate::scalars::date::DateFunction;
31use crate::scalars::expression::ExpressionFunction;
32use crate::scalars::hll_count::HllCalcFunction;
33use crate::scalars::ip::IpFunctions;
34use crate::scalars::json::JsonFunction;
35use crate::scalars::matches::MatchesFunction;
36use crate::scalars::matches_term::MatchesTermFunction;
37use crate::scalars::math::MathFunction;
38use crate::scalars::primary_key::DecodePrimaryKeyFunction;
39use crate::scalars::string::register_string_functions;
40use crate::scalars::timestamp::TimestampFunction;
41use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
42use crate::scalars::vector::VectorFunction as VectorScalarFunction;
43use crate::system::SystemFunction;
44
45#[derive(Default)]
46pub struct FunctionRegistry {
47    functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
48    aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
49    table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
50    window_functions: RwLock<HashMap<String, WindowUDF>>,
51}
52
53impl FunctionRegistry {
54    /// Register a function in the registry by converting it into a `ScalarFunctionFactory`.
55    ///
56    /// # Arguments
57    ///
58    /// * `func` - An object that can be converted into a `ScalarFunctionFactory`.
59    ///
60    /// The function is inserted into the internal function map, keyed by its name.
61    /// If a function with the same name already exists, it will be replaced.
62    pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
63        let func = func.into();
64        let _ = self
65            .functions
66            .write()
67            .unwrap()
68            .insert(func.name().to_string(), func);
69    }
70
71    /// Register a scalar function in the registry.
72    pub fn register_scalar(&self, func: impl Function + 'static) {
73        let func = Arc::new(func) as FunctionRef;
74
75        for alias in func.aliases() {
76            let func: ScalarFunctionFactory = func.clone().into();
77            let alias = ScalarFunctionFactory {
78                name: alias.clone(),
79                ..func
80            };
81            self.register(alias);
82        }
83
84        self.register(func)
85    }
86
87    /// Register an aggregate function in the registry.
88    pub fn register_aggr(&self, func: AggregateUDF) {
89        let _ = self
90            .aggregate_functions
91            .write()
92            .unwrap()
93            .insert(func.name().to_string(), func);
94    }
95
96    /// Register a table function
97    pub fn register_table_function(&self, func: TableFunction) {
98        let _ = self
99            .table_functions
100            .write()
101            .unwrap()
102            .insert(func.name().to_string(), Arc::new(func));
103    }
104
105    /// Register a window function (UDWF).
106    pub fn register_window(&self, func: WindowUDF) {
107        let _ = self
108            .window_functions
109            .write()
110            .unwrap()
111            .insert(func.name().to_string(), func);
112    }
113
114    pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
115        self.functions.read().unwrap().get(name).cloned()
116    }
117
118    /// Returns a list of all scalar functions registered in the registry.
119    pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
120        self.functions.read().unwrap().values().cloned().collect()
121    }
122
123    /// Returns a list of all aggregate functions registered in the registry.
124    pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
125        self.aggregate_functions
126            .read()
127            .unwrap()
128            .values()
129            .cloned()
130            .collect()
131    }
132
133    pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
134        self.table_functions
135            .read()
136            .unwrap()
137            .values()
138            .cloned()
139            .collect()
140    }
141
142    /// Returns a list of all window functions registered in the registry.
143    pub fn window_functions(&self) -> Vec<WindowUDF> {
144        self.window_functions
145            .read()
146            .unwrap()
147            .values()
148            .cloned()
149            .collect()
150    }
151
152    /// Returns true if an aggregate function with the given name exists in the registry.
153    pub fn is_aggr_func_exist(&self, name: &str) -> bool {
154        self.aggregate_functions.read().unwrap().contains_key(name)
155    }
156}
157
158pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
159    let function_registry = FunctionRegistry::default();
160
161    // Utility functions
162    MathFunction::register(&function_registry);
163    TimestampFunction::register(&function_registry);
164    DateFunction::register(&function_registry);
165    ExpressionFunction::register(&function_registry);
166    UddSketchCalcFunction::register(&function_registry);
167    HllCalcFunction::register(&function_registry);
168    DecodePrimaryKeyFunction::register(&function_registry);
169
170    // Full text search function
171    MatchesFunction::register(&function_registry);
172    MatchesTermFunction::register(&function_registry);
173
174    // System and administration functions
175    SystemFunction::register(&function_registry);
176    AdminFunction::register(&function_registry);
177
178    // Json related functions
179    JsonFunction::register(&function_registry);
180
181    // String related functions
182    register_string_functions(&function_registry);
183
184    // Vector related functions
185    VectorScalarFunction::register(&function_registry);
186    VectorAggrFunction::register(&function_registry);
187
188    // Geo functions
189    #[cfg(feature = "geo")]
190    crate::scalars::geo::GeoFunctions::register(&function_registry);
191    #[cfg(feature = "geo")]
192    crate::aggrs::geo::GeoFunction::register(&function_registry);
193
194    // Ip functions
195    IpFunctions::register(&function_registry);
196
197    // Approximate functions
198    ApproximateFunction::register(&function_registry);
199
200    // CountHash function
201    CountHash::register(&function_registry);
202
203    // state function of supported aggregate functions
204    StateMergeHelper::register(&function_registry);
205
206    // Anomaly detection window functions
207    AnomalyFunction::register(&function_registry);
208
209    Arc::new(function_registry)
210});
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// A fresh registry is empty; registering one scalar function makes it discoverable.
    #[test]
    fn test_function_registry() {
        let function_registry = FunctionRegistry::default();

        // Nothing is registered yet.
        assert!(function_registry.get_function("test_and").is_none());
        assert!(function_registry.scalar_functions().is_empty());

        function_registry.register_scalar(TestAndFunction::default());

        // The function can now be found by name and is the only scalar entry.
        assert!(function_registry.get_function("test_and").is_some());
        assert_eq!(1, function_registry.scalar_functions().len());
    }
}
227}