From 3a527c0fd59f5e98eda7abef5c54f03bfab0c22f Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 7 Mar 2023 23:39:45 +0800 Subject: [PATCH] feat: impl proc macro `range_fn` and some `aggr_over_time` functions (#1072) * impl range_fn proc macro Signed-off-by: Ruihang Xia * impl some aggr_over_time fn Signed-off-by: Ruihang Xia * impl present_over_time and absent_over_time Signed-off-by: Ruihang Xia * accomplish planner, and correct type cast Signed-off-by: Ruihang Xia * clean up Signed-off-by: Ruihang Xia * document the macro Signed-off-by: Ruihang Xia * fix styles Signed-off-by: Ruihang Xia * update irate/idelta test Signed-off-by: Ruihang Xia * add test cases Signed-off-by: Ruihang Xia * fix clippy Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- Cargo.lock | 2 + src/common/function-macro/Cargo.toml | 1 + src/common/function-macro/src/lib.rs | 31 ++ src/common/function-macro/src/range_fn.rs | 230 ++++++++++++++ src/promql/Cargo.toml | 1 + src/promql/src/functions.rs | 7 + src/promql/src/functions/aggr_over_time.rs | 335 +++++++++++++++++++++ src/promql/src/functions/idelta.rs | 44 +-- src/promql/src/functions/test_util.rs | 43 +++ src/promql/src/planner.rs | 30 +- 10 files changed, 688 insertions(+), 36 deletions(-) create mode 100644 src/common/function-macro/src/range_fn.rs create mode 100644 src/promql/src/functions/aggr_over_time.rs create mode 100644 src/promql/src/functions/test_util.rs diff --git a/Cargo.lock b/Cargo.lock index 76ec94f73e..8cdfe1285e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1522,6 +1522,7 @@ dependencies = [ "arc-swap", "common-query", "datatypes", + "proc-macro2", "quote", "snafu", "static_assertions", @@ -5394,6 +5395,7 @@ dependencies = [ "catalog", "common-catalog", "common-error", + "common-function-macro", "datafusion", "datatypes", "futures", diff --git a/src/common/function-macro/Cargo.toml b/src/common/function-macro/Cargo.toml index 70d2008c22..c9078dc13e 100644 --- a/src/common/function-macro/Cargo.toml 
+++ b/src/common/function-macro/Cargo.toml @@ -10,6 +10,7 @@ proc-macro = true [dependencies] quote = "1.0" syn = "1.0" +proc-macro2 = "1.0" [dev-dependencies] arc-swap = "1.0" diff --git a/src/common/function-macro/src/lib.rs b/src/common/function-macro/src/lib.rs index 5eb2c93b9c..0afa5b7e7b 100644 --- a/src/common/function-macro/src/lib.rs +++ b/src/common/function-macro/src/lib.rs @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod range_fn; + use proc_macro::TokenStream; use quote::{quote, quote_spanned}; +use range_fn::process_range_fn; use syn::parse::Parser; use syn::spanned::Spanned; use syn::{parse_macro_input, DeriveInput, ItemStruct}; @@ -83,3 +86,31 @@ pub fn as_aggr_func_creator(_args: TokenStream, input: TokenStream) -> TokenStre } .into() } + +/// Attribute macro to convert an arithmetic function to a range function. The annotated function +/// should accept several arrays as input and return a single value as output. This procedural +/// macro works with any number of input parameters. The return type can be either a primitive type +/// or one wrapped in `Option`. +/// +/// # Example +/// Take `count_over_time()` in PromQL as an example: +/// ```rust, ignore +/// /// The count of all values in the specified interval. +/// #[range_fn( +/// name = "CountOverTime", +/// ret = "Float64Array", +/// display_name = "prom_count_over_time" +/// )] +/// pub fn count_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> f64 { +/// values.len() as f64 +/// } +/// ``` +/// +/// # Arguments +/// - `name`: The name of the generated [ScalarUDF] struct. +/// - `ret`: The return type of the generated UDF function. +/// - `display_name`: The display name of the generated UDF function. 
+#[proc_macro_attribute] +pub fn range_fn(args: TokenStream, input: TokenStream) -> TokenStream { + process_range_fn(args, input) +} diff --git a/src/common/function-macro/src/range_fn.rs b/src/common/function-macro/src/range_fn.rs new file mode 100644 index 0000000000..c7fcedbdae --- /dev/null +++ b/src/common/function-macro/src/range_fn.rs @@ -0,0 +1,230 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::token::Comma; +use syn::{ + parse_macro_input, Attribute, AttributeArgs, FnArg, Ident, ItemFn, Meta, MetaNameValue, + NestedMeta, Signature, Type, TypeReference, Visibility, +}; + +/// Internal util macro to early return on error. +macro_rules! ok { + ($item:expr) => { + match $item { + Ok(item) => item, + Err(e) => return e.into_compile_error().into(), + } + }; +} + +pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenStream { + // extract arg map + let arg_pairs = parse_macro_input!(args as AttributeArgs); + let arg_span = arg_pairs[0].span(); + let arg_map = ok!(extract_arg_map(arg_pairs)); + + // decompose the fn block + let compute_fn = parse_macro_input!(input as ItemFn); + let ItemFn { + attrs, + vis, + sig, + block, + } = compute_fn; + + // extract fn arg list + let Signature { + inputs, + ident: fn_name, + .. 
+ } = &sig; + let arg_types = ok!(extract_input_types(inputs)); + + // build the struct and its impl block + let struct_code = build_struct( + attrs, + vis, + ok!(get_ident(&arg_map, "name", arg_span)), + ok!(get_ident(&arg_map, "display_name", arg_span)), + ); + let calc_fn_code = build_calc_fn( + ok!(get_ident(&arg_map, "name", arg_span)), + arg_types, + fn_name.clone(), + ok!(get_ident(&arg_map, "ret", arg_span)), + ); + // preserve this fn, but remove its `pub` modifier + let input_fn_code: TokenStream = quote! { + #sig { #block } + } + .into(); + + let mut result = TokenStream::new(); + result.extend(struct_code); + result.extend(calc_fn_code); + result.extend(input_fn_code); + result +} + +/// Extract a String <-> Ident map from the attribute args. +fn extract_arg_map(args: Vec) -> Result, syn::Error> { + args.into_iter() + .map(|meta| { + if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) = meta { + let name = path.get_ident().unwrap().to_string(); + let ident = match lit { + syn::Lit::Str(lit_str) => lit_str.parse::(), + _ => Err(syn::Error::new( + lit.span(), + "Unexpected attribute format. Expected `name = \"value\"`", + )), + }?; + Ok((name, ident)) + } else { + Err(syn::Error::new( + meta.span(), + "Unexpected attribute format. Expected `name = \"value\"`", + )) + } + }) + .collect::, syn::Error>>() +} + +/// Helper function to get an Ident from the previous arg map. +fn get_ident(map: &HashMap, key: &str, span: Span) -> Result { + map.get(key) + .cloned() + .ok_or_else(|| syn::Error::new(span, format!("Expect attribute {key} but not found"))) +} + +/// Extract the argument list from the annotated function. 
+fn extract_input_types(inputs: &Punctuated) -> Result, syn::Error> { + inputs + .iter() + .map(|arg| match arg { + FnArg::Receiver(receiver) => Err(syn::Error::new(receiver.span(), "expected bool")), + FnArg::Typed(pat_type) => Ok(*pat_type.ty.clone()), + }) + .collect() +} + +fn build_struct( + attrs: Vec, + vis: Visibility, + name: Ident, + display_name_ident: Ident, +) -> TokenStream { + let display_name = display_name_ident.to_string(); + quote! { + #(#attrs)* + #[derive(Debug)] + #vis struct #name {} + + impl #name { + pub const fn name() -> &'static str { + #display_name + } + + pub fn scalar_udf() -> ScalarUDF { + ScalarUDF { + name: Self::name().to_string(), + signature: Signature::new( + TypeSignature::Exact(Self::input_type()), + Volatility::Immutable, + ), + return_type: Arc::new(|_| Ok(Arc::new(Self::return_type()))), + fun: Arc::new(Self::calc), + } + } + + // TODO(ruihang): this should be parameterized + // time index column and value column + fn input_type() -> Vec { + vec![ + RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)), + RangeArray::convert_data_type(DataType::Float64), + ] + } + + // TODO(ruihang): this should be parameterized + fn return_type() -> DataType { + DataType::Float64 + } + } + } + .into() +} + +fn build_calc_fn( + name: Ident, + param_types: Vec, + fn_name: Ident, + ret_type: Ident, +) -> TokenStream { + let param_names = param_types + .iter() + .enumerate() + .map(|(i, ty)| Ident::new(&format!("param_{}", i), ty.span())) + .collect::>(); + let unref_param_types = param_types + .iter() + .map(|ty| { + if let Type::Reference(TypeReference { elem, .. 
}) = ty { + elem.as_ref().clone() + } else { + ty.clone() + } + }) + .collect::>(); + let num_params = param_types.len(); + let param_numbers = (0..num_params).collect::>(); + let range_array_names = param_names + .iter() + .map(|name| Ident::new(&format!("{}_range_array", name), name.span())) + .collect::>(); + let first_range_array_name = range_array_names.first().unwrap().clone(); + + quote! { + impl #name { + fn calc(input: &[ColumnarValue]) -> Result { + assert_eq!(input.len(), #num_params); + + #( let #range_array_names = RangeArray::try_new(extract_array(&input[#param_numbers])?.data().clone().into())?; )* + + // TODO(ruihang): add ensure!() + + let mut result_array = Vec::new(); + for index in 0..#first_range_array_name.len(){ + #( let #param_names = #range_array_names.get(index).unwrap().as_any().downcast_ref::<#unref_param_types>().unwrap().clone(); )* + + // TODO(ruihang): add ensure!() to check length + + let result = #fn_name(#( &#param_names, )*); + result_array.push(result); + } + + let result = ColumnarValue::Array(Arc::new(#ret_type::from_iter(result_array))); + Ok(result) + } + } + } + .into() +} diff --git a/src/promql/Cargo.toml b/src/promql/Cargo.toml index 20a499b942..a4d2f37e96 100644 --- a/src/promql/Cargo.toml +++ b/src/promql/Cargo.toml @@ -11,6 +11,7 @@ bytemuck = "1.12" catalog = { path = "../catalog" } common-error = { path = "../common/error" } common-catalog = { path = "../common/catalog" } +common-function-macro = { path = "../common/function-macro" } datafusion.workspace = true datatypes = { path = "../datatypes" } futures = "0.3" diff --git a/src/promql/src/functions.rs b/src/promql/src/functions.rs index aaa6aa0258..b0e64fd506 100644 --- a/src/promql/src/functions.rs +++ b/src/promql/src/functions.rs @@ -12,9 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+mod aggr_over_time; mod idelta; mod increase; +#[cfg(test)] +mod test_util; +pub use aggr_over_time::{ + AbsentOverTime, AvgOverTime, CountOverTime, LastOverTime, MaxOverTime, MinOverTime, + PresentOverTime, SumOverTime, +}; use datafusion::arrow::array::ArrayRef; use datafusion::error::DataFusionError; use datafusion::physical_plan::ColumnarValue; diff --git a/src/promql/src/functions/aggr_over_time.rs b/src/promql/src/functions/aggr_over_time.rs new file mode 100644 index 0000000000..451008d420 --- /dev/null +++ b/src/promql/src/functions/aggr_over_time.rs @@ -0,0 +1,335 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use common_function_macro::range_fn; +use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray}; +use datafusion::arrow::datatypes::TimeUnit; +use datafusion::common::DataFusionError; +use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility}; +use datafusion::physical_plan::ColumnarValue; +use datatypes::arrow::array::Array; +use datatypes::arrow::compute; +use datatypes::arrow::datatypes::DataType; + +use crate::functions::extract_array; +use crate::range_array::RangeArray; + +/// The average value of all points in the specified interval. 
+#[range_fn( + name = "AvgOverTime", + ret = "Float64Array", + display_name = "prom_avg_over_time" +)] +pub fn avg_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option { + compute::sum(values).map(|result| result / values.len() as f64) +} + +/// The minimum value of all points in the specified interval. +#[range_fn( + name = "MinOverTime", + ret = "Float64Array", + display_name = "prom_min_over_time" +)] +pub fn min_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option { + compute::min(values) +} + +/// The maximum value of all points in the specified interval. +#[range_fn( + name = "MaxOverTime", + ret = "Float64Array", + display_name = "prom_max_over_time" +)] +pub fn max_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option { + compute::max(values) +} + +/// The sum of all values in the specified interval. +#[range_fn( + name = "SumOverTime", + ret = "Float64Array", + display_name = "prom_sum_over_time" +)] +pub fn sum_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option { + compute::sum(values) +} + +/// The count of all values in the specified interval. +#[range_fn( + name = "CountOverTime", + ret = "Float64Array", + display_name = "prom_count_over_time" +)] +pub fn count_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> f64 { + values.len() as f64 +} + +/// The most recent point value in specified interval. +#[range_fn( + name = "LastOverTime", + ret = "Float64Array", + display_name = "prom_last_over_time" +)] +pub fn last_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option { + values.values().last().copied() +} + +/// absent_over_time returns an empty vector if the range vector passed to it has any +/// elements (floats or native histograms) and a 1-element vector with the value 1 if +/// the range vector passed to it has no elements. 
+#[range_fn( + name = "AbsentOverTime", + ret = "Float64Array", + display_name = "prom_absent_over_time" +)] +pub fn absent_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option { + if values.is_empty() { + Some(1.0) + } else { + None + } +} + +/// the value 1 for any series in the specified interval. +#[range_fn( + name = "PresentOverTime", + ret = "Float64Array", + display_name = "prom_present_over_time" +)] +pub fn present_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option { + if values.is_empty() { + None + } else { + Some(1.0) + } +} + +// TODO(ruihang): support quantile_over_time, stddev_over_time, and stdvar_over_time + +#[cfg(test)] +mod test { + use super::*; + use crate::functions::test_util::simple_range_udf_runner; + + // build timestamp range and value range arrays for test + fn build_test_range_arrays() -> (RangeArray, RangeArray) { + let ts_array = Arc::new(TimestampMillisecondArray::from_iter( + [ + 1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000, 17000, 200000, 500000, + ] + .into_iter() + .map(Some), + )); + let ranges = [ + (0, 2), + (0, 5), + (1, 1), // only 1 element + (2, 0), // empty range + (2, 0), // empty range + (3, 3), + (4, 3), + (5, 3), + (8, 1), // only 1 element + (9, 0), // empty range + ]; + + let values_array = Arc::new(Float64Array::from_iter([ + 12.345678, 87.654321, 31.415927, 27.182818, 70.710678, 41.421356, 57.735027, 69.314718, + 98.019802, 1.98019802, 61.803399, + ])); + + let ts_range_array = RangeArray::from_ranges(ts_array, ranges).unwrap(); + let value_range_array = RangeArray::from_ranges(values_array, ranges).unwrap(); + + (ts_range_array, value_range_array) + } + + #[test] + fn calculate_avg_over_time() { + let (ts_array, value_array) = build_test_range_arrays(); + simple_range_udf_runner( + AvgOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + Some(49.9999995), + Some(45.8618844), + Some(87.654321), + None, + None, + Some(46.438284), + 
Some(56.62235366666667), + Some(56.15703366666667), + Some(98.019802), + None, + ], + ); + } + + #[test] + fn calculate_min_over_time() { + let (ts_array, value_array) = build_test_range_arrays(); + simple_range_udf_runner( + MinOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + Some(12.345678), + Some(12.345678), + Some(87.654321), + None, + None, + Some(27.182818), + Some(41.421356), + Some(41.421356), + Some(98.019802), + None, + ], + ); + } + + #[test] + fn calculate_max_over_time() { + let (ts_array, value_array) = build_test_range_arrays(); + simple_range_udf_runner( + MaxOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + Some(87.654321), + Some(87.654321), + Some(87.654321), + None, + None, + Some(70.710678), + Some(70.710678), + Some(69.314718), + Some(98.019802), + None, + ], + ); + } + + #[test] + fn calculate_sum_over_time() { + let (ts_array, value_array) = build_test_range_arrays(); + simple_range_udf_runner( + SumOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + Some(99.999999), + Some(229.309422), + Some(87.654321), + None, + None, + Some(139.314852), + Some(169.867061), + Some(168.471101), + Some(98.019802), + None, + ], + ); + } + + #[test] + fn calculate_count_over_time() { + let (ts_array, value_array) = build_test_range_arrays(); + simple_range_udf_runner( + CountOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + Some(2.0), + Some(5.0), + Some(1.0), + Some(0.0), + Some(0.0), + Some(3.0), + Some(3.0), + Some(3.0), + Some(1.0), + Some(0.0), + ], + ); + } + + #[test] + fn calculate_last_over_time() { + let (ts_array, value_array) = build_test_range_arrays(); + simple_range_udf_runner( + LastOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + Some(87.654321), + Some(70.710678), + Some(87.654321), + None, + None, + Some(41.421356), + Some(57.735027), + Some(69.314718), + Some(98.019802), + None, + ], + ); + } + + #[test] + fn calculate_absent_over_time() { + let (ts_array, value_array) = 
build_test_range_arrays(); + simple_range_udf_runner( + AbsentOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + None, + None, + None, + Some(1.0), + Some(1.0), + None, + None, + None, + None, + Some(1.0), + ], + ); + } + + #[test] + fn calculate_present_over_time() { + let (ts_array, value_array) = build_test_range_arrays(); + simple_range_udf_runner( + PresentOverTime::scalar_udf(), + ts_array, + value_array, + vec![ + Some(1.0), + Some(1.0), + Some(1.0), + None, + None, + Some(1.0), + Some(1.0), + Some(1.0), + Some(1.0), + None, + ], + ); + } +} diff --git a/src/promql/src/functions/idelta.rs b/src/promql/src/functions/idelta.rs index 6d138dbb99..3c3b8e52b5 100644 --- a/src/promql/src/functions/idelta.rs +++ b/src/promql/src/functions/idelta.rs @@ -169,36 +169,7 @@ impl Display for IDelta { mod test { use super::*; - - fn idelta_runner(input_ts: RangeArray, input_value: RangeArray, expected: Vec) { - let input = vec![ - ColumnarValue::Array(Arc::new(input_ts.into_dict())), - ColumnarValue::Array(Arc::new(input_value.into_dict())), - ]; - let output = extract_array(&IDelta::::calc(&input).unwrap()) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .values() - .to_vec(); - assert_eq!(output, expected); - } - - fn irate_runner(input_ts: RangeArray, input_value: RangeArray, expected: Vec) { - let input = vec![ - ColumnarValue::Array(Arc::new(input_ts.into_dict())), - ColumnarValue::Array(Arc::new(input_value.into_dict())), - ]; - let output = extract_array(&IDelta::::calc(&input).unwrap()) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .values() - .to_vec(); - assert_eq!(output, expected); - } + use crate::functions::test_util::simple_range_udf_runner; #[test] fn basic_idelta_and_irate() { @@ -214,21 +185,26 @@ mod test { ])); let values_ranges = [(0, 2), (0, 5), (1, 1), (3, 3), (8, 1), (9, 0)]; + // test idelta let ts_range_array = RangeArray::from_ranges(ts_array.clone(), ts_ranges).unwrap(); let value_range_array = 
RangeArray::from_ranges(values_array.clone(), values_ranges).unwrap(); - idelta_runner( + simple_range_udf_runner( + IDelta::::scalar_udf(), ts_range_array, value_range_array, - vec![1.0, -5.0, 0.0, 6.0, 0.0, 0.0], + vec![Some(1.0), Some(-5.0), None, Some(6.0), None, None], ); + // test irate let ts_range_array = RangeArray::from_ranges(ts_array, ts_ranges).unwrap(); let value_range_array = RangeArray::from_ranges(values_array, values_ranges).unwrap(); - irate_runner( + simple_range_udf_runner( + IDelta::::scalar_udf(), ts_range_array, value_range_array, - vec![0.5, 0.0, 0.0, 3.0, 0.0, 0.0], + // the second point represents a counter reset + vec![Some(0.5), Some(0.0), None, Some(3.0), None, None], ); } } diff --git a/src/promql/src/functions/test_util.rs b/src/promql/src/functions/test_util.rs new file mode 100644 index 0000000000..5b9d4adef3 --- /dev/null +++ b/src/promql/src/functions/test_util.rs @@ -0,0 +1,43 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use datafusion::arrow::array::Float64Array; +use datafusion::logical_expr::ScalarUDF; +use datafusion::physical_plan::ColumnarValue; + +use crate::functions::extract_array; +use crate::range_array::RangeArray; + +/// Runner to run range UDFs that only require ts range and value range. 
+pub fn simple_range_udf_runner( + range_fn: ScalarUDF, + input_ts: RangeArray, + input_value: RangeArray, + expected: Vec>, +) { + let input = vec![ + ColumnarValue::Array(Arc::new(input_ts.into_dict())), + ColumnarValue::Array(Arc::new(input_value.into_dict())), + ]; + let eval_result: Vec> = extract_array(&(range_fn.fun)(&input).unwrap()) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + assert_eq!(eval_result, expected) +} diff --git a/src/promql/src/planner.rs b/src/promql/src/planner.rs index 11ccbcdcb8..c2ad32446f 100644 --- a/src/promql/src/planner.rs +++ b/src/promql/src/planner.rs @@ -48,7 +48,10 @@ use crate::error::{ use crate::extension_plan::{ EmptyMetric, InstantManipulate, Millisecond, RangeManipulate, SeriesDivide, SeriesNormalize, }; -use crate::functions::{IDelta, Increase}; +use crate::functions::{ + AbsentOverTime, AvgOverTime, CountOverTime, IDelta, Increase, LastOverTime, MaxOverTime, + MinOverTime, PresentOverTime, SumOverTime, +}; const LEFT_PLAN_JOIN_ALIAS: &str = "lhs"; @@ -667,6 +670,14 @@ impl PromPlanner { "increase" => ScalarFunc::Udf(Increase::scalar_udf()), "idelta" => ScalarFunc::Udf(IDelta::::scalar_udf()), "irate" => ScalarFunc::Udf(IDelta::::scalar_udf()), + "avg_over_time" => ScalarFunc::Udf(AvgOverTime::scalar_udf()), + "min_over_time" => ScalarFunc::Udf(MinOverTime::scalar_udf()), + "max_over_time" => ScalarFunc::Udf(MaxOverTime::scalar_udf()), + "sum_over_time" => ScalarFunc::Udf(SumOverTime::scalar_udf()), + "count_over_time" => ScalarFunc::Udf(CountOverTime::scalar_udf()), + "last_over_time" => ScalarFunc::Udf(LastOverTime::scalar_udf()), + "absent_over_time" => ScalarFunc::Udf(AbsentOverTime::scalar_udf()), + "present_over_time" => ScalarFunc::Udf(PresentOverTime::scalar_udf()), _ => ScalarFunc::DataFusionBuiltin( BuiltinScalarFunction::from_str(func.name).map_err(|_| { UnsupportedExprSnafu { @@ -1592,7 +1603,6 @@ mod test { \n Sort: some_metric.tag_0 DESC NULLS LAST, some_metric.timestamp 
DESC NULLS LAST [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ \n Filter: some_metric.timestamp >= TimestampMillisecond(-1000, None) AND some_metric.timestamp <= TimestampMillisecond(100000000, None) [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ \n TableScan: some_metric, unsupported_filters=[timestamp >= TimestampMillisecond(-1000, None), timestamp <= TimestampMillisecond(100000000, None)] [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]" - ); indie_query_plan_compare(query, expected).await; @@ -1609,7 +1619,23 @@ mod test { \n Sort: some_metric.tag_0 DESC NULLS LAST, some_metric.timestamp DESC NULLS LAST [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ \n Filter: some_metric.timestamp >= TimestampMillisecond(-1000, None) AND some_metric.timestamp <= TimestampMillisecond(100000000, None) [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ \n TableScan: some_metric, unsupported_filters=[timestamp >= TimestampMillisecond(-1000, None), timestamp <= TimestampMillisecond(100000000, None)] [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]" + ); + indie_query_plan_compare(query, expected).await; + } + + #[tokio::test] + async fn count_over_time() { + let query = "count_over_time(some_metric[5m])"; + let expected = String::from( + "Filter: prom_count_over_time(timestamp_range,field_0) IS NOT NULL [timestamp:Timestamp(Millisecond, None), prom_count_over_time(timestamp_range,field_0):Float64;N, tag_0:Utf8]\ + \n Projection: some_metric.timestamp, prom_count_over_time(timestamp_range, field_0) AS prom_count_over_time(timestamp_range,field_0), some_metric.tag_0 [timestamp:Timestamp(Millisecond, None), prom_count_over_time(timestamp_range,field_0):Float64;N, tag_0:Utf8]\ + \n PromRangeManipulate: req range=[0..100000000], interval=[5000], eval range=[300000], time index=[timestamp], values=[\"field_0\"] [tag_0:Utf8, 
timestamp:Timestamp(Millisecond, None), field_0:Dictionary(Int64, Float64);N, timestamp_range:Dictionary(Int64, Timestamp(Millisecond, None))]\ + \n PromSeriesNormalize: offset=[0], time index=[timestamp] [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ + \n PromSeriesDivide: tags=[\"tag_0\"] [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ + \n Sort: some_metric.tag_0 DESC NULLS LAST, some_metric.timestamp DESC NULLS LAST [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ + \n Filter: some_metric.timestamp >= TimestampMillisecond(-1000, None) AND some_metric.timestamp <= TimestampMillisecond(100000000, None) [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]\ + \n TableScan: some_metric, unsupported_filters=[timestamp >= TimestampMillisecond(-1000, None), timestamp <= TimestampMillisecond(100000000, None)] [tag_0:Utf8, timestamp:Timestamp(Millisecond, None), field_0:Float64;N]" ); indie_query_plan_compare(query, expected).await;