diff --git a/src/common/macro/src/range_fn.rs b/src/common/macro/src/range_fn.rs index 622e21ef6c..c907f1d0d1 100644 --- a/src/common/macro/src/range_fn.rs +++ b/src/common/macro/src/range_fn.rs @@ -56,6 +56,18 @@ pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenSt } = &sig; let arg_types = ok!(extract_input_types(inputs)); + // with format like Float64Array + let array_types = arg_types + .iter() + .map(|ty| { + if let Type::Reference(TypeReference { elem, .. }) = ty { + elem.as_ref().clone() + } else { + ty.clone() + } + }) + .collect::>(); + // build the struct and its impl block // only do this when `display_name` is specified if let Ok(display_name) = get_ident(&arg_map, "display_name", arg_span) { @@ -64,6 +76,8 @@ pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenSt vis, ok!(get_ident(&arg_map, "name", arg_span)), display_name, + array_types, + ok!(get_ident(&arg_map, "ret", arg_span)), ); result.extend(struct_code); } @@ -90,6 +104,8 @@ fn build_struct( vis: Visibility, name: Ident, display_name_ident: Ident, + array_types: Vec, + return_array_type: Ident, ) -> TokenStream { let display_name = display_name_ident.to_string(); quote! { @@ -114,18 +130,12 @@ fn build_struct( } } - // TODO(ruihang): this should be parameterized - // time index column and value column fn input_type() -> Vec { - vec![ - RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)), - RangeArray::convert_data_type(DataType::Float64), - ] + vec![#( RangeArray::convert_data_type(#array_types::new_null(0).data_type().clone()), )*] } - // TODO(ruihang): this should be parameterized fn return_type() -> DataType { - DataType::Float64 + #return_array_type::new_null(0).data_type().clone() } } } @@ -160,6 +170,7 @@ fn build_calc_fn( .map(|name| Ident::new(&format!("{}_range_array", name), name.span())) .collect::>(); let first_range_array_name = range_array_names.first().unwrap().clone(); + let first_param_name = param_names.first().unwrap().clone(); quote! { impl #name { @@ -168,13 +179,29 @@ fn build_calc_fn( #( let #range_array_names = RangeArray::try_new(extract_array(&input[#param_numbers])?.to_data().into())?; )* - // TODO(ruihang): add ensure!() + // check arrays len + { + let len_first = #first_range_array_name.len(); + #( + if len_first != #range_array_names.len() { + return Err(DataFusionError::Execution(format!("RangeArray have different lengths in PromQL function {}: array1={}, array2={}", #name::name(), len_first, #range_array_names.len()))); + } + )* + } let mut result_array = Vec::new(); for index in 0..#first_range_array_name.len(){ #( let #param_names = #range_array_names.get(index).unwrap().as_any().downcast_ref::<#unref_param_types>().unwrap().clone(); )* - // TODO(ruihang): add ensure!() to check length + // check element len + { + let len_first = #first_param_name.len(); + #( + if len_first != #param_names.len() { + return Err(DataFusionError::Execution(format!("RangeArray's element {} have different lengths in PromQL function {}: array1={}, array2={}", index, #name::name(), len_first, #param_names.len()))); + } + )* + } let result = #fn_name(#( &#param_names, )*); result_array.push(result); diff --git a/src/promql/src/functions/aggr_over_time.rs b/src/promql/src/functions/aggr_over_time.rs index f57b2d5612..e02e4a7d91 100644 --- a/src/promql/src/functions/aggr_over_time.rs +++ b/src/promql/src/functions/aggr_over_time.rs @@ -16,7 +16,6 @@ use std::sync::Arc; use common_macro::range_fn; use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray}; -use datafusion::arrow::datatypes::TimeUnit; use datafusion::common::DataFusionError; use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility}; use datafusion::physical_plan::ColumnarValue; diff --git a/src/promql/src/functions/changes.rs b/src/promql/src/functions/changes.rs index a8b29c9cbd..bb547e87f1 100644 --- a/src/promql/src/functions/changes.rs +++ b/src/promql/src/functions/changes.rs @@ -19,7 +19,6 @@ use std::sync::Arc; use common_macro::range_fn; use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray}; -use datafusion::arrow::datatypes::TimeUnit; use datafusion::common::DataFusionError; use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility}; use datafusion::physical_plan::ColumnarValue; diff --git a/src/promql/src/functions/deriv.rs b/src/promql/src/functions/deriv.rs index e573242e82..462637ceb5 100644 --- a/src/promql/src/functions/deriv.rs +++ b/src/promql/src/functions/deriv.rs @@ -19,7 +19,6 @@ use std::sync::Arc; use common_macro::range_fn; use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray}; -use datafusion::arrow::datatypes::TimeUnit; use datafusion::common::DataFusionError; use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility}; use datafusion::physical_plan::ColumnarValue; diff --git a/src/promql/src/functions/resets.rs b/src/promql/src/functions/resets.rs index 218e190873..00dec32d01 100644 --- a/src/promql/src/functions/resets.rs +++ b/src/promql/src/functions/resets.rs @@ -19,7 +19,6 @@ use std::sync::Arc; use common_macro::range_fn; use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray}; -use datafusion::arrow::datatypes::TimeUnit; use datafusion::common::DataFusionError; use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility}; use datafusion::physical_plan::ColumnarValue;