common_function/
function_registry.rs1use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::{AggregateUDF, WindowUDF};
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::anomaly::AnomalyFunction;
30use crate::scalars::date::DateFunction;
31use crate::scalars::expression::ExpressionFunction;
32use crate::scalars::hll_count::HllCalcFunction;
33use crate::scalars::ip::IpFunctions;
34use crate::scalars::json::JsonFunction;
35use crate::scalars::matches::MatchesFunction;
36use crate::scalars::matches_term::MatchesTermFunction;
37use crate::scalars::math::MathFunction;
38use crate::scalars::primary_key::DecodePrimaryKeyFunction;
39use crate::scalars::string::register_string_functions;
40use crate::scalars::timestamp::TimestampFunction;
41use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
42use crate::scalars::vector::VectorFunction as VectorScalarFunction;
43use crate::system::SystemFunction;
44
45#[derive(Default)]
46pub struct FunctionRegistry {
47 functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
48 aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
49 table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
50 window_functions: RwLock<HashMap<String, WindowUDF>>,
51}
52
53impl FunctionRegistry {
54 pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
63 let func = func.into();
64 let _ = self
65 .functions
66 .write()
67 .unwrap()
68 .insert(func.name().to_string(), func);
69 }
70
71 pub fn register_scalar(&self, func: impl Function + 'static) {
73 let func = Arc::new(func) as FunctionRef;
74
75 for alias in func.aliases() {
76 let func: ScalarFunctionFactory = func.clone().into();
77 let alias = ScalarFunctionFactory {
78 name: alias.clone(),
79 ..func
80 };
81 self.register(alias);
82 }
83
84 self.register(func)
85 }
86
87 pub fn register_aggr(&self, func: AggregateUDF) {
89 let _ = self
90 .aggregate_functions
91 .write()
92 .unwrap()
93 .insert(func.name().to_string(), func);
94 }
95
96 pub fn register_table_function(&self, func: TableFunction) {
98 let _ = self
99 .table_functions
100 .write()
101 .unwrap()
102 .insert(func.name().to_string(), Arc::new(func));
103 }
104
105 pub fn register_window(&self, func: WindowUDF) {
107 let _ = self
108 .window_functions
109 .write()
110 .unwrap()
111 .insert(func.name().to_string(), func);
112 }
113
114 pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
115 self.functions.read().unwrap().get(name).cloned()
116 }
117
118 pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
120 self.functions.read().unwrap().values().cloned().collect()
121 }
122
123 pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
125 self.aggregate_functions
126 .read()
127 .unwrap()
128 .values()
129 .cloned()
130 .collect()
131 }
132
133 pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
134 self.table_functions
135 .read()
136 .unwrap()
137 .values()
138 .cloned()
139 .collect()
140 }
141
142 pub fn window_functions(&self) -> Vec<WindowUDF> {
144 self.window_functions
145 .read()
146 .unwrap()
147 .values()
148 .cloned()
149 .collect()
150 }
151
152 pub fn is_aggr_func_exist(&self, name: &str) -> bool {
154 self.aggregate_functions.read().unwrap().contains_key(name)
155 }
156}
157
158pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
159 let function_registry = FunctionRegistry::default();
160
161 MathFunction::register(&function_registry);
163 TimestampFunction::register(&function_registry);
164 DateFunction::register(&function_registry);
165 ExpressionFunction::register(&function_registry);
166 UddSketchCalcFunction::register(&function_registry);
167 HllCalcFunction::register(&function_registry);
168 DecodePrimaryKeyFunction::register(&function_registry);
169
170 MatchesFunction::register(&function_registry);
172 MatchesTermFunction::register(&function_registry);
173
174 SystemFunction::register(&function_registry);
176 AdminFunction::register(&function_registry);
177
178 JsonFunction::register(&function_registry);
180
181 register_string_functions(&function_registry);
183
184 VectorScalarFunction::register(&function_registry);
186 VectorAggrFunction::register(&function_registry);
187
188 #[cfg(feature = "geo")]
190 crate::scalars::geo::GeoFunctions::register(&function_registry);
191 #[cfg(feature = "geo")]
192 crate::aggrs::geo::GeoFunction::register(&function_registry);
193
194 IpFunctions::register(&function_registry);
196
197 ApproximateFunction::register(&function_registry);
199
200 CountHash::register(&function_registry);
202
203 StateMergeHelper::register(&function_registry);
205
206 AnomalyFunction::register(&function_registry);
208
209 Arc::new(function_registry)
210});
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// A scalar function registered on an empty registry becomes retrievable by name, and
    /// it is the only scalar factory present afterwards.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::default();

        // A freshly constructed registry is empty.
        assert!(registry.get_function("test_and").is_none());
        assert!(registry.scalar_functions().is_empty());

        registry.register_scalar(TestAndFunction::default());

        // The function is now resolvable, and no alias entries were added alongside it.
        assert!(registry.get_function("test_and").is_some());
        assert_eq!(registry.scalar_functions().len(), 1);
    }
}