diff --git a/Cargo.lock b/Cargo.lock index 928e2fb267..88516b3d76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2336,6 +2336,7 @@ dependencies = [ "num-traits", "once_cell", "paste", + "promql", "s2", "serde", "serde_json", diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index 292fc07cf0..4cebea4fb0 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -47,6 +47,7 @@ num = "0.4" num-traits = "0.2" once_cell.workspace = true paste.workspace = true +promql.workspace = true s2 = { version = "0.0.12", optional = true } serde.workspace = true serde_json.workspace = true diff --git a/src/common/function/src/function_registry.rs b/src/common/function/src/function_registry.rs index 134f040526..683412b20c 100644 --- a/src/common/function/src/function_registry.rs +++ b/src/common/function/src/function_registry.rs @@ -39,7 +39,7 @@ use crate::system::SystemFunction; #[derive(Default)] pub struct FunctionRegistry { - functions: RwLock>, + scalar_functions: RwLock>, async_functions: RwLock>, aggregate_functions: RwLock>, } @@ -48,7 +48,7 @@ impl FunctionRegistry { pub fn register(&self, func: impl Into) { let func = func.into(); let _ = self - .functions + .scalar_functions .write() .unwrap() .insert(func.name().to_string(), func); @@ -87,13 +87,17 @@ impl FunctionRegistry { .collect() } - #[cfg(test)] - pub fn get_function(&self, name: &str) -> Option { - self.functions.read().unwrap().get(name).cloned() + pub fn get_scalar_function(&self, name: &str) -> Option { + self.scalar_functions.read().unwrap().get(name).cloned() } pub fn scalar_functions(&self) -> Vec { - self.functions.read().unwrap().values().cloned().collect() + self.scalar_functions + .read() + .unwrap() + .values() + .cloned() + .collect() } pub fn aggregate_functions(&self) -> Vec { @@ -144,6 +148,11 @@ pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { // Approximate functions ApproximateFunction::register(&function_registry); + // PromQL aggregate functions + for aggr in promql::functions::aggr_funcs() { + function_registry.register_aggr(aggr); + } + Arc::new(function_registry) }); @@ -156,10 +165,10 @@ mod tests { fn test_function_registry() { let registry = FunctionRegistry::default(); - assert!(registry.get_function("test_and").is_none()); + assert!(registry.get_scalar_function("test_and").is_none()); assert!(registry.scalar_functions().is_empty()); registry.register_scalar(TestAndFunction); - let _ = registry.get_function("test_and").unwrap(); + let _ = registry.get_scalar_function("test_and").unwrap(); assert_eq!(1, registry.scalar_functions().len()); } } diff --git a/src/promql/src/functions.rs b/src/promql/src/functions.rs index 12841dee09..53c58930e6 100644 --- a/src/promql/src/functions.rs +++ b/src/promql/src/functions.rs @@ -34,6 +34,7 @@ pub use changes::Changes; use datafusion::arrow::array::{ArrayRef, Float64Array, TimestampMillisecondArray}; use datafusion::error::DataFusionError; use datafusion::physical_plan::ColumnarValue; +use datafusion_expr::{AggregateUDF, ScalarUDF}; pub use deriv::Deriv; pub use extrapolate_rate::{Delta, Increase, Rate}; pub use holt_winters::HoltWinters; @@ -44,6 +45,39 @@ pub use quantile_aggr::{quantile_udaf, QUANTILE_NAME}; pub use resets::Resets; pub use round::Round; +/// Range functions for PromQL. +pub fn range_funcs() -> Vec { + vec![ + IDelta::::scalar_udf(), + IDelta::::scalar_udf(), + Rate::scalar_udf(), + Increase::scalar_udf(), + Delta::scalar_udf(), + Resets::scalar_udf(), + Changes::scalar_udf(), + Deriv::scalar_udf(), + Round::scalar_udf(), + AvgOverTime::scalar_udf(), + MinOverTime::scalar_udf(), + MaxOverTime::scalar_udf(), + SumOverTime::scalar_udf(), + CountOverTime::scalar_udf(), + LastOverTime::scalar_udf(), + AbsentOverTime::scalar_udf(), + PresentOverTime::scalar_udf(), + StddevOverTime::scalar_udf(), + StdvarOverTime::scalar_udf(), + QuantileOverTime::scalar_udf(), + PredictLinear::scalar_udf(), + HoltWinters::scalar_udf(), + ] +} + +/// Aggregate functions for PromQL. +pub fn aggr_funcs() -> Vec { + vec![quantile_udaf()] +} + /// Extracts an array from a `ColumnarValue`. /// /// If the `ColumnarValue` is a scalar, it converts it to an array of size 1. diff --git a/src/promql/src/functions/quantile_aggr.rs b/src/promql/src/functions/quantile_aggr.rs index 5652f57342..90e96276c2 100644 --- a/src/promql/src/functions/quantile_aggr.rs +++ b/src/promql/src/functions/quantile_aggr.rs @@ -40,8 +40,8 @@ pub struct QuantileAccumulator { /// Create a quantile `AggregateUDF` for PromQL quantile operator, /// which calculates φ-quantile (0 ≤ φ ≤ 1) over dimensions -pub fn quantile_udaf() -> Arc { - Arc::new(create_udaf( +pub fn quantile_udaf() -> AggregateUDF { + create_udaf( QUANTILE_NAME, // Input type: (φ, values) vec![DataType::Float64, DataType::Float64], @@ -63,7 +63,7 @@ pub fn quantile_udaf() -> Arc { )] .into(), )]), - )) + ) } impl QuantileAccumulator { diff --git a/src/query/src/promql/planner.rs b/src/query/src/promql/planner.rs index 6df5bdbbba..5aa0a987ee 100644 --- a/src/query/src/promql/planner.rs +++ b/src/query/src/promql/planner.rs @@ -1948,7 +1948,7 @@ impl PromPlanner { token::T_QUANTILE => { let q = Self::get_param_value_as_f64(op, param)?; non_col_args.push(lit(q)); - quantile_udaf() + Arc::new(quantile_udaf()) } token::T_AVG => avg_udaf(), token::T_COUNT_VALUES | token::T_COUNT => count_udaf(), diff --git a/src/query/src/query_engine/default_serializer.rs b/src/query/src/query_engine/default_serializer.rs index 50f7c79ff3..259075d6ae 100644 --- a/src/query/src/query_engine/default_serializer.rs +++ b/src/query/src/query_engine/default_serializer.rs @@ -28,11 +28,6 @@ use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; use datafusion::logical_expr::LogicalPlan; use datafusion_expr::UserDefinedLogicalNode; use greptime_proto::substrait_extension::MergeScan as PbMergeScan; -use promql::functions::{ - quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, IDelta, - Increase, LastOverTime, MaxOverTime, MinOverTime, PresentOverTime, Rate, Resets, Round, - StddevOverTime, StdvarOverTime, SumOverTime, -}; use prost::Message; use session::context::QueryContextRef; use snafu::ResultExt; @@ -117,12 +112,15 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { let mut session_state = SessionStateBuilder::new_from_existing(self.session_state.clone()) .with_catalog_list(catalog_list) .build(); + // Substrait decoder will look up the UDFs in SessionState, so we need to register them // Note: the query context must be passed to set the timezone // We MUST register the UDFs after we build the session state, otherwise the UDFs will be lost // if they have the same name as the default UDFs or their alias. // e.g. The default UDF `to_char()` has an alias `date_format()`, if we register a UDF with the name `date_format()` // before we build the session state, the UDF will be lost. + + // Scalar functions for func in FUNCTION_REGISTRY.scalar_functions() { let udf = func.provide(FunctionContext { query_ctx: self.query_ctx.clone(), @@ -133,6 +131,15 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { .context(RegisterUdfSnafu { name: func.name() })?; } + // PromQL range functions + for func in promql::functions::range_funcs() { + let name = func.name().to_string(); + session_state + .register_udf(Arc::new(func)) + .context(RegisterUdfSnafu { name })?; + } + + // Aggregate functions for func in FUNCTION_REGISTRY.aggregate_functions() { let name = func.name().to_string(); session_state @@ -140,29 +147,6 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder { .context(RegisterUdfSnafu { name })?; } - let _ = session_state.register_udaf(quantile_udaf()); - - let _ = session_state.register_udf(Arc::new(IDelta::::scalar_udf())); - let _ = session_state.register_udf(Arc::new(IDelta::::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Rate::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Increase::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Delta::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Resets::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Changes::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Deriv::scalar_udf())); - let _ = session_state.register_udf(Arc::new(Round::scalar_udf())); - let _ = session_state.register_udf(Arc::new(AvgOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(MinOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(MaxOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(SumOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(CountOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(LastOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(AbsentOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf())); - let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf())); - // TODO(ruihang): add quantile_over_time, predict_linear, holt_winters, round - let logical_plan = DFLogicalSubstraitConvertor .decode(message, session_state) .await