Compare commits

...

1 Commits

Author SHA1 Message Date
Zhenchi
01081ef97b refactor: unify function registry (Part 2)
Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
2025-06-11 08:21:11 +00:00
7 changed files with 69 additions and 40 deletions

1
Cargo.lock generated
View File

@@ -2336,6 +2336,7 @@ dependencies = [
"num-traits",
"once_cell",
"paste",
"promql",
"s2",
"serde",
"serde_json",

View File

@@ -47,6 +47,7 @@ num = "0.4"
num-traits = "0.2"
once_cell.workspace = true
paste.workspace = true
promql.workspace = true
s2 = { version = "0.0.12", optional = true }
serde.workspace = true
serde_json.workspace = true

View File

@@ -39,7 +39,7 @@ use crate::system::SystemFunction;
#[derive(Default)]
pub struct FunctionRegistry {
functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
scalar_functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
async_functions: RwLock<HashMap<String, AsyncFunctionRef>>,
aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
}
@@ -48,7 +48,7 @@ impl FunctionRegistry {
pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
let func = func.into();
let _ = self
.functions
.scalar_functions
.write()
.unwrap()
.insert(func.name().to_string(), func);
@@ -87,13 +87,17 @@ impl FunctionRegistry {
.collect()
}
#[cfg(test)]
pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
self.functions.read().unwrap().get(name).cloned()
pub fn get_scalar_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
self.scalar_functions.read().unwrap().get(name).cloned()
}
pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
self.functions.read().unwrap().values().cloned().collect()
self.scalar_functions
.read()
.unwrap()
.values()
.cloned()
.collect()
}
pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
@@ -144,6 +148,11 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
// Approximate functions
ApproximateFunction::register(&function_registry);
// PromQL aggregate functions
for aggr in promql::functions::aggr_funcs() {
function_registry.register_aggr(aggr);
}
Arc::new(function_registry)
});
@@ -156,10 +165,10 @@ mod tests {
fn test_function_registry() {
let registry = FunctionRegistry::default();
assert!(registry.get_function("test_and").is_none());
assert!(registry.get_scalar_function("test_and").is_none());
assert!(registry.scalar_functions().is_empty());
registry.register_scalar(TestAndFunction);
let _ = registry.get_function("test_and").unwrap();
let _ = registry.get_scalar_function("test_and").unwrap();
assert_eq!(1, registry.scalar_functions().len());
}
}

View File

@@ -34,6 +34,7 @@ pub use changes::Changes;
use datafusion::arrow::array::{ArrayRef, Float64Array, TimestampMillisecondArray};
use datafusion::error::DataFusionError;
use datafusion::physical_plan::ColumnarValue;
use datafusion_expr::{AggregateUDF, ScalarUDF};
pub use deriv::Deriv;
pub use extrapolate_rate::{Delta, Increase, Rate};
pub use holt_winters::HoltWinters;
@@ -44,6 +45,39 @@ pub use quantile_aggr::{quantile_udaf, QUANTILE_NAME};
pub use resets::Resets;
pub use round::Round;
/// Range functions for PromQL.
pub fn range_funcs() -> Vec<ScalarUDF> {
vec![
IDelta::<false>::scalar_udf(),
IDelta::<true>::scalar_udf(),
Rate::scalar_udf(),
Increase::scalar_udf(),
Delta::scalar_udf(),
Resets::scalar_udf(),
Changes::scalar_udf(),
Deriv::scalar_udf(),
Round::scalar_udf(),
AvgOverTime::scalar_udf(),
MinOverTime::scalar_udf(),
MaxOverTime::scalar_udf(),
SumOverTime::scalar_udf(),
CountOverTime::scalar_udf(),
LastOverTime::scalar_udf(),
AbsentOverTime::scalar_udf(),
PresentOverTime::scalar_udf(),
StddevOverTime::scalar_udf(),
StdvarOverTime::scalar_udf(),
QuantileOverTime::scalar_udf(),
PredictLinear::scalar_udf(),
HoltWinters::scalar_udf(),
]
}
/// Aggregate functions for PromQL.
pub fn aggr_funcs() -> Vec<AggregateUDF> {
vec![quantile_udaf()]
}
/// Extracts an array from a `ColumnarValue`.
///
/// If the `ColumnarValue` is a scalar, it converts it to an array of size 1.

View File

@@ -40,8 +40,8 @@ pub struct QuantileAccumulator {
/// Create a quantile `AggregateUDF` for PromQL quantile operator,
/// which calculates φ-quantile (0 ≤ φ ≤ 1) over dimensions
pub fn quantile_udaf() -> Arc<AggregateUDF> {
Arc::new(create_udaf(
pub fn quantile_udaf() -> AggregateUDF {
create_udaf(
QUANTILE_NAME,
// Input type: (φ, values)
vec![DataType::Float64, DataType::Float64],
@@ -63,7 +63,7 @@ pub fn quantile_udaf() -> Arc<AggregateUDF> {
)]
.into(),
)]),
))
)
}
impl QuantileAccumulator {

View File

@@ -1948,7 +1948,7 @@ impl PromPlanner {
token::T_QUANTILE => {
let q = Self::get_param_value_as_f64(op, param)?;
non_col_args.push(lit(q));
quantile_udaf()
Arc::new(quantile_udaf())
}
token::T_AVG => avg_udaf(),
token::T_COUNT_VALUES | token::T_COUNT => count_udaf(),

View File

@@ -28,11 +28,6 @@ use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
use datafusion::logical_expr::LogicalPlan;
use datafusion_expr::UserDefinedLogicalNode;
use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
use promql::functions::{
quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, IDelta,
Increase, LastOverTime, MaxOverTime, MinOverTime, PresentOverTime, Rate, Resets, Round,
StddevOverTime, StdvarOverTime, SumOverTime,
};
use prost::Message;
use session::context::QueryContextRef;
use snafu::ResultExt;
@@ -117,12 +112,15 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder {
let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone())
.with_catalog_list(catalog_list)
.build();
// Substrait decoder will look up the UDFs in SessionState, so we need to register them
// Note: the query context must be passed to set the timezone
// We MUST register the UDFs after we build the session state, otherwise the UDFs will be lost
// if they have the same name as the default UDFs or their alias.
// e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()`
// before we build the session state, the UDF will be lost.
// Scalar functions
for func in FUNCTION_REGISTRY.scalar_functions() {
let udf = func.provide(FunctionContext {
query_ctx: self.query_ctx.clone(),
@@ -133,6 +131,15 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder {
.context(RegisterUdfSnafu { name: func.name() })?;
}
// PromQL range functions
for func in promql::functions::range_funcs() {
let name = func.name().to_string();
session_state
.register_udf(Arc::new(func))
.context(RegisterUdfSnafu { name })?;
}
// Aggregate functions
for func in FUNCTION_REGISTRY.aggregate_functions() {
let name = func.name().to_string();
session_state
@@ -140,29 +147,6 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder {
.context(RegisterUdfSnafu { name })?;
}
let _ = session_state.register_udaf(quantile_udaf());
let _ = session_state.register_udf(Arc::new(IDelta::<false>::scalar_udf()));
let _ = session_state.register_udf(Arc::new(IDelta::<true>::scalar_udf()));
let _ = session_state.register_udf(Arc::new(Rate::scalar_udf()));
let _ = session_state.register_udf(Arc::new(Increase::scalar_udf()));
let _ = session_state.register_udf(Arc::new(Delta::scalar_udf()));
let _ = session_state.register_udf(Arc::new(Resets::scalar_udf()));
let _ = session_state.register_udf(Arc::new(Changes::scalar_udf()));
let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf()));
let _ = session_state.register_udf(Arc::new(Round::scalar_udf()));
let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
// TODO(ruihang): add quantile_over_time, predict_linear, holt_winters, round
let logical_plan = DFLogicalSubstraitConvertor
.decode(message, session_state)
.await