feat: support arbitrary constant expression in PromQL function (#6315)

* refactor holt_winters, predict_linear, quantile, round

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix clippy

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* some sqlness result

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* support some functions

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* make all sqlness cases pass

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix other sqlness cases

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* some refactor

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix clippy

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

---------

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2025-06-16 23:12:27 +08:00
committed by GitHub
parent 2a3445c72c
commit be4e0d589e
10 changed files with 579 additions and 302 deletions

View File

@@ -31,6 +31,60 @@ use crate::error;
use crate::functions::extract_array;
use crate::range_array::RangeArray;
/// `FactorIterator` iterates over a `ColumnarValue` that can be a scalar or an array.
struct FactorIterator<'a> {
is_scalar: bool,
array: Option<&'a Float64Array>,
scalar_val: f64,
index: usize,
len: usize,
}
impl<'a> FactorIterator<'a> {
fn new(value: &'a ColumnarValue, len: usize) -> Self {
let (is_scalar, array, scalar_val) = match value {
ColumnarValue::Array(arr) => {
(false, arr.as_any().downcast_ref::<Float64Array>(), f64::NAN)
}
ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => (true, None, *val),
_ => (true, None, f64::NAN),
};
Self {
is_scalar,
array,
scalar_val,
index: 0,
len,
}
}
}
impl<'a> Iterator for FactorIterator<'a> {
type Item = f64;
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.len {
return None;
}
self.index += 1;
if self.is_scalar {
return Some(self.scalar_val);
}
if let Some(array) = self.array {
if array.is_null(self.index - 1) {
Some(f64::NAN)
} else {
Some(array.value(self.index - 1))
}
} else {
Some(f64::NAN)
}
}
}
/// There are 3 variants of smoothing functions:
/// 1) "Simple exponential smoothing": only the `level` component (the weighted average of the observations) is used to make forecasts.
/// This method is applied for time-series data that does not exhibit trend or seasonality.
@@ -44,16 +98,9 @@ use crate::range_array::RangeArray;
/// the "Holt's linear"("double exponential smoothing") suits better and reflects implementation.
/// There's the [discussion](https://github.com/prometheus/prometheus/issues/2458) in the Prometheus Github that dates back
/// to 2017 highlighting the naming/implementation mismatch.
pub struct HoltWinters {
sf: f64,
tf: f64,
}
pub struct HoltWinters;
impl HoltWinters {
fn new(sf: f64, tf: f64) -> Self {
Self { sf, tf }
}
pub const fn name() -> &'static str {
"prom_holt_winters"
}
@@ -80,46 +127,31 @@ impl HoltWinters {
Self::input_type(),
Self::return_type(),
Volatility::Volatile,
Arc::new(move |input: &_| Self::create_function(input)?.calc(input)) as _,
Arc::new(Self::holt_winters) as _,
)
}
fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
if inputs.len() != 4 {
return Err(DataFusionError::Plan(
"HoltWinters function should have 4 inputs".to_string(),
));
}
let ColumnarValue::Scalar(ScalarValue::Float64(Some(sf))) = inputs[2] else {
return Err(DataFusionError::Plan(
"HoltWinters function's third input should be a scalar float64".to_string(),
));
};
let ColumnarValue::Scalar(ScalarValue::Float64(Some(tf))) = inputs[3] else {
return Err(DataFusionError::Plan(
"HoltWinters function's fourth input should be a scalar float64".to_string(),
));
};
Ok(Self::new(sf, tf))
}
fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
// construct matrix from input.
// The third one is level param, the fourth - trend param which are included in fields.
assert_eq!(input.len(), 4);
fn holt_winters(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
error::ensure(
input.len() == 4,
DataFusionError::Plan("prom_holt_winters function should have 4 inputs".to_string()),
)?;
let ts_array = extract_array(&input[0])?;
let value_array = extract_array(&input[1])?;
let sf_col = &input[2];
let tf_col = &input[3];
let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
let num_rows = ts_range.len();
error::ensure(
ts_range.len() == value_range.len(),
num_rows == value_range.len(),
DataFusionError::Execution(format!(
"{}: input arrays should have the same length, found {} and {}",
Self::name(),
ts_range.len(),
num_rows,
value_range.len()
)),
)?;
@@ -142,9 +174,17 @@ impl HoltWinters {
// calculation
let mut result_array = Vec::with_capacity(ts_range.len());
for index in 0..ts_range.len() {
let timestamps = ts_range.get(index).unwrap();
let values = value_range.get(index).unwrap();
let sf_iter = FactorIterator::new(sf_col, num_rows);
let tf_iter = FactorIterator::new(tf_col, num_rows);
let iter = (0..num_rows)
.map(|i| (ts_range.get(i), value_range.get(i)))
.zip(sf_iter.zip(tf_iter));
for ((timestamps, values), (sf, tf)) in iter {
let timestamps = timestamps.unwrap();
let values = values.unwrap();
let values = values
.as_any()
.downcast_ref::<Float64Array>()
@@ -159,7 +199,8 @@ impl HoltWinters {
values.len()
)),
)?;
result_array.push(holt_winter_impl(values, self.sf, self.tf));
result_array.push(holt_winter_impl(values, sf, tf));
}
let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));

View File

@@ -31,16 +31,9 @@ use crate::error;
use crate::functions::{extract_array, linear_regression};
use crate::range_array::RangeArray;
pub struct PredictLinear {
/// Duration. The second param of (`predict_linear(v range-vector, t scalar)`).
t: i64,
}
pub struct PredictLinear;
impl PredictLinear {
fn new(t: i64) -> Self {
Self { t }
}
pub const fn name() -> &'static str {
"prom_predict_linear"
}
@@ -59,29 +52,19 @@ impl PredictLinear {
input_types,
DataType::Float64,
Volatility::Volatile,
Arc::new(move |input: &_| Self::create_function(input)?.predict_linear(input)) as _,
Arc::new(Self::predict_linear) as _,
)
}
fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
if inputs.len() != 3 {
return Err(DataFusionError::Plan(
"PredictLinear function should have 3 inputs".to_string(),
));
}
let ColumnarValue::Scalar(ScalarValue::Int64(Some(t))) = inputs[2] else {
return Err(DataFusionError::Plan(
"PredictLinear function's third input should be a scalar int64".to_string(),
));
};
Ok(Self::new(t))
}
fn predict_linear(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
error::ensure(
input.len() == 3,
DataFusionError::Plan("prom_predict_linear function should have 3 inputs".to_string()),
)?;
fn predict_linear(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
// construct matrix from input.
assert_eq!(input.len(), 3);
let ts_array = extract_array(&input[0])?;
let value_array = extract_array(&input[1])?;
let t_col = &input[2];
let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
@@ -111,36 +94,46 @@ impl PredictLinear {
)),
)?;
// calculation
let t_iter: Box<dyn Iterator<Item = Option<i64>>> = match t_col {
ColumnarValue::Scalar(t_scalar) => {
let t = if let ScalarValue::Int64(Some(t_val)) = t_scalar {
*t_val
} else {
// For `ScalarValue::Int64(None)` or other scalar types, returns NULL array,
// which conforms to PromQL's behavior.
let null_array = Float64Array::new_null(ts_range.len());
return Ok(ColumnarValue::Array(Arc::new(null_array)));
};
Box::new((0..ts_range.len()).map(move |_| Some(t)))
}
ColumnarValue::Array(t_array) => {
let t_array = t_array
.as_any()
.downcast_ref::<datafusion::arrow::array::Int64Array>()
.ok_or_else(|| {
DataFusionError::Execution(format!(
"{}: expect Int64 as t array's type, found {}",
Self::name(),
t_array.data_type()
))
})?;
error::ensure(
t_array.len() == ts_range.len(),
DataFusionError::Execution(format!(
"{}: t array should have the same length as other columns, found {} and {}",
Self::name(),
t_array.len(),
ts_range.len()
)),
)?;
Box::new(t_array.iter())
}
};
let mut result_array = Vec::with_capacity(ts_range.len());
for index in 0..ts_range.len() {
let timestamps = ts_range
.get(index)
.unwrap()
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap()
.clone();
let values = value_range
.get(index)
.unwrap()
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.clone();
error::ensure(
timestamps.len() == values.len(),
DataFusionError::Execution(format!(
"{}: input arrays should have the same length, found {} and {}",
Self::name(),
timestamps.len(),
values.len()
)),
)?;
let ret = predict_linear_impl(&timestamps, &values, self.t);
for (index, t) in t_iter.enumerate() {
let (timestamps, values) = get_ts_values(&ts_range, &value_range, index, Self::name())?;
let ret = predict_linear_impl(&timestamps, &values, t.unwrap());
result_array.push(ret);
}
@@ -149,6 +142,38 @@ impl PredictLinear {
}
}
fn get_ts_values(
ts_range: &RangeArray,
value_range: &RangeArray,
index: usize,
func_name: &str,
) -> Result<(TimestampMillisecondArray, Float64Array), DataFusionError> {
let timestamps = ts_range
.get(index)
.unwrap()
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap()
.clone();
let values = value_range
.get(index)
.unwrap()
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.clone();
error::ensure(
timestamps.len() == values.len(),
DataFusionError::Execution(format!(
"{}: time and value arrays in a group should have the same length, found {} and {}",
func_name,
timestamps.len(),
values.len()
)),
)?;
Ok((timestamps, values))
}
fn predict_linear_impl(
timestamps: &TimestampMillisecondArray,
values: &Float64Array,

View File

@@ -28,15 +28,9 @@ use crate::error;
use crate::functions::extract_array;
use crate::range_array::RangeArray;
pub struct QuantileOverTime {
quantile: f64,
}
pub struct QuantileOverTime;
impl QuantileOverTime {
fn new(quantile: f64) -> Self {
Self { quantile }
}
pub const fn name() -> &'static str {
"prom_quantile_over_time"
}
@@ -55,32 +49,21 @@ impl QuantileOverTime {
input_types,
DataType::Float64,
Volatility::Volatile,
Arc::new(move |input: &_| Self::create_function(input)?.quantile_over_time(input)) as _,
Arc::new(Self::quantile_over_time) as _,
)
}
fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
if inputs.len() != 3 {
return Err(DataFusionError::Plan(
"QuantileOverTime function should have 3 inputs".to_string(),
));
}
let ColumnarValue::Scalar(ScalarValue::Float64(Some(quantile))) = inputs[2] else {
return Err(DataFusionError::Plan(
"QuantileOverTime function's third input should be a scalar float64".to_string(),
));
};
Ok(Self::new(quantile))
}
fn quantile_over_time(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
error::ensure(
input.len() == 3,
DataFusionError::Plan(
"prom_quantile_over_time function should have 3 inputs".to_string(),
),
)?;
fn quantile_over_time(
&self,
input: &[ColumnarValue],
) -> Result<ColumnarValue, DataFusionError> {
// construct matrix from input.
assert_eq!(input.len(), 2);
let ts_array = extract_array(&input[0])?;
let value_array = extract_array(&input[1])?;
let quantile_col = &input[2];
let ts_range: RangeArray = RangeArray::try_new(ts_array.to_data().into())?;
let value_range: RangeArray = RangeArray::try_new(value_array.to_data().into())?;
@@ -113,27 +96,85 @@ impl QuantileOverTime {
// calculation
let mut result_array = Vec::with_capacity(ts_range.len());
for index in 0..ts_range.len() {
let timestamps = ts_range.get(index).unwrap();
let values = value_range.get(index).unwrap();
let values = values
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.values();
error::ensure(
timestamps.len() == values.len(),
DataFusionError::Execution(format!(
"{}: input arrays should have the same length, found {} and {}",
Self::name(),
timestamps.len(),
values.len()
)),
)?;
match quantile_col {
ColumnarValue::Scalar(quantile_scalar) => {
let quantile = if let ScalarValue::Float64(Some(q)) = quantile_scalar {
*q
} else {
// For `ScalarValue::Float64(None)` or other scalar types, use NAN,
// which conforms to PromQL's behavior.
f64::NAN
};
let retule = quantile_impl(values, self.quantile);
for index in 0..ts_range.len() {
let timestamps = ts_range.get(index).unwrap();
let values = value_range.get(index).unwrap();
let values = values
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.values();
error::ensure(
timestamps.len() == values.len(),
DataFusionError::Execution(format!(
"{}: time and value arrays in a group should have the same length, found {} and {}",
Self::name(),
timestamps.len(),
values.len()
)),
)?;
result_array.push(retule);
let result = quantile_impl(values, quantile);
result_array.push(result);
}
}
ColumnarValue::Array(quantile_array) => {
let quantile_array = quantile_array
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
DataFusionError::Execution(format!(
"{}: expect Float64 as quantile array's type, found {}",
Self::name(),
quantile_array.data_type()
))
})?;
error::ensure(
quantile_array.len() == ts_range.len(),
DataFusionError::Execution(format!(
"{}: quantile array should have the same length as other columns, found {} and {}",
Self::name(),
quantile_array.len(),
ts_range.len()
)),
)?;
for index in 0..ts_range.len() {
let timestamps = ts_range.get(index).unwrap();
let values = value_range.get(index).unwrap();
let values = values
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.values();
error::ensure(
timestamps.len() == values.len(),
DataFusionError::Execution(format!(
"{}: time and value arrays in a group should have the same length, found {} and {}",
Self::name(),
timestamps.len(),
values.len()
)),
)?;
let quantile = if quantile_array.is_null(index) {
f64::NAN
} else {
quantile_array.value(index)
};
let result = quantile_impl(values, quantile);
result_array.push(result);
}
}
}
let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));

View File

@@ -17,21 +17,16 @@ use std::sync::Arc;
use datafusion::error::DataFusionError;
use datafusion_common::ScalarValue;
use datafusion_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility};
use datatypes::arrow::array::AsArray;
use datatypes::arrow::array::{AsArray, Float64Array, PrimitiveArray};
use datatypes::arrow::datatypes::{DataType, Float64Type};
use datatypes::compute;
use datatypes::arrow::error::ArrowError;
use crate::error;
use crate::functions::extract_array;
pub struct Round {
nearest: f64,
}
pub struct Round;
impl Round {
fn new(nearest: f64) -> Self {
Self { nearest }
}
pub const fn name() -> &'static str {
"prom_round"
}
@@ -50,39 +45,62 @@ impl Round {
Self::input_type(),
Self::return_type(),
Volatility::Volatile,
Arc::new(move |input: &_| Self::create_function(input)?.calc(input)) as _,
Arc::new(Self::round) as _,
)
}
fn create_function(inputs: &[ColumnarValue]) -> Result<Self, DataFusionError> {
if inputs.len() != 2 {
return Err(DataFusionError::Plan(
"Round function should have 2 inputs".to_string(),
));
}
let ColumnarValue::Scalar(ScalarValue::Float64(Some(nearest))) = inputs[1] else {
return Err(DataFusionError::Plan(
"Round function's second input should be a scalar float64".to_string(),
));
};
Ok(Self::new(nearest))
}
fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
assert_eq!(input.len(), 2);
fn round(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
error::ensure(
input.len() == 2,
DataFusionError::Plan("prom_round function should have 2 inputs".to_string()),
)?;
let value_array = extract_array(&input[0])?;
let nearest_col = &input[1];
if self.nearest == 0.0 {
let values = value_array.as_primitive::<Float64Type>();
let result = compute::unary::<_, _, Float64Type>(values, |a| a.round());
Ok(ColumnarValue::Array(Arc::new(result) as _))
} else {
let values = value_array.as_primitive::<Float64Type>();
let nearest = self.nearest;
let result =
compute::unary::<_, _, Float64Type>(values, |a| ((a / nearest).round() * nearest));
Ok(ColumnarValue::Array(Arc::new(result) as _))
match nearest_col {
ColumnarValue::Scalar(nearest_scalar) => {
let nearest = if let ScalarValue::Float64(Some(val)) = nearest_scalar {
*val
} else {
let null_array = Float64Array::new_null(value_array.len());
return Ok(ColumnarValue::Array(Arc::new(null_array)));
};
let op = |a: f64| {
if nearest == 0.0 {
a.round()
} else {
(a / nearest).round() * nearest
}
};
let result: PrimitiveArray<Float64Type> =
value_array.as_primitive::<Float64Type>().unary(op);
Ok(ColumnarValue::Array(Arc::new(result) as _))
}
ColumnarValue::Array(nearest_array) => {
let value_array = value_array.as_primitive::<Float64Type>();
let nearest_array = nearest_array.as_primitive::<Float64Type>();
error::ensure(
value_array.len() == nearest_array.len(),
DataFusionError::Execution(format!(
"input arrays should have the same length, found {} and {}",
value_array.len(),
nearest_array.len()
)),
)?;
let result: PrimitiveArray<Float64Type> =
datatypes::arrow::compute::binary(value_array, nearest_array, |a, nearest| {
if nearest == 0.0 {
a.round()
} else {
(a / nearest).round() * nearest
}
})
.map_err(|err: ArrowError| DataFusionError::ArrowError(err, None))?;
Ok(ColumnarValue::Array(Arc::new(result) as _))
}
}
}
}

View File

@@ -1435,27 +1435,22 @@ impl PromPlanner {
for arg in args {
match *arg.clone() {
PromExpr::Aggregate(_)
| PromExpr::Unary(_)
| PromExpr::Binary(_)
| PromExpr::Paren(_)
| PromExpr::Subquery(_)
PromExpr::Subquery(_)
| PromExpr::VectorSelector(_)
| PromExpr::MatrixSelector(_)
| PromExpr::Extension(_)
| PromExpr::Aggregate(_)
| PromExpr::Paren(_)
| PromExpr::Call(_) => {
if result.input.replace(*arg.clone()).is_some() {
MultipleVectorSnafu { expr: *arg.clone() }.fail()?;
}
}
PromExpr::NumberLiteral(NumberLiteral { val, .. }) => {
let scalar_value = ScalarValue::Float64(Some(val));
result.literals.push(DfExpr::Literal(scalar_value));
}
PromExpr::StringLiteral(StringLiteral { val, .. }) => {
let scalar_value = ScalarValue::Utf8(Some(val));
result.literals.push(DfExpr::Literal(scalar_value));
_ => {
let expr =
Self::get_param_as_literal_expr(&Some(Box::new(*arg.clone())), None, None)?;
result.literals.push(expr);
}
}
}
@@ -1507,7 +1502,13 @@ impl PromPlanner {
"stddev_over_time" => ScalarFunc::Udf(Arc::new(StddevOverTime::scalar_udf())),
"stdvar_over_time" => ScalarFunc::Udf(Arc::new(StdvarOverTime::scalar_udf())),
"quantile_over_time" => ScalarFunc::Udf(Arc::new(QuantileOverTime::scalar_udf())),
"predict_linear" => ScalarFunc::Udf(Arc::new(PredictLinear::scalar_udf())),
"predict_linear" => {
other_input_exprs[0] = DfExpr::Cast(Cast {
expr: Box::new(other_input_exprs[0].clone()),
data_type: ArrowDataType::Int64,
});
ScalarFunc::Udf(Arc::new(PredictLinear::scalar_udf()))
}
"holt_winters" => ScalarFunc::Udf(Arc::new(HoltWinters::scalar_udf())),
"time" => {
exprs.push(build_special_time_expr(

View File

@@ -29,9 +29,9 @@ use datafusion::logical_expr::LogicalPlan;
use datafusion_expr::UserDefinedLogicalNode;
use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
use promql::functions::{
quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, IDelta,
Increase, LastOverTime, MaxOverTime, MinOverTime, PresentOverTime, Rate, Resets, Round,
StddevOverTime, StdvarOverTime, SumOverTime,
quantile_udaf, AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, HoltWinters,
IDelta, Increase, LastOverTime, MaxOverTime, MinOverTime, PredictLinear, PresentOverTime,
QuantileOverTime, Rate, Resets, Round, StddevOverTime, StdvarOverTime, SumOverTime,
};
use prost::Message;
use session::context::QueryContextRef;
@@ -161,7 +161,9 @@ impl SubstraitPlanDecoder for DefaultPlanDecoder {
let _ = session_state.register_udf(Arc::new(PresentOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(StddevOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(StdvarOverTime::scalar_udf()));
// TODO(ruihang): add quantile_over_time, predict_linear, holt_winters, round
let _ = session_state.register_udf(Arc::new(QuantileOverTime::scalar_udf()));
let _ = session_state.register_udf(Arc::new(PredictLinear::scalar_udf()));
let _ = session_state.register_udf(Arc::new(HoltWinters::scalar_udf()));
let logical_plan = DFLogicalSubstraitConvertor
.decode(message, session_state)