diff --git a/src/datatypes/src/lib.rs b/src/datatypes/src/lib.rs index 3ce78322fe..0b2588c753 100644 --- a/src/datatypes/src/lib.rs +++ b/src/datatypes/src/lib.rs @@ -32,5 +32,5 @@ pub mod types; pub mod value; pub mod vectors; -pub use arrow; +pub use arrow::{self, compute}; pub use error::{Error, Result}; diff --git a/src/promql/src/functions.rs b/src/promql/src/functions.rs index dd12e1b616..4209a517c6 100644 --- a/src/promql/src/functions.rs +++ b/src/promql/src/functions.rs @@ -21,6 +21,7 @@ mod idelta; mod predict_linear; mod quantile; mod resets; +mod round; #[cfg(test)] mod test_util; @@ -39,6 +40,7 @@ pub use idelta::IDelta; pub use predict_linear::PredictLinear; pub use quantile::QuantileOverTime; pub use resets::Resets; +pub use round::Round; pub(crate) fn extract_array(columnar_value: &ColumnarValue) -> Result { if let ColumnarValue::Array(array) = columnar_value { diff --git a/src/promql/src/functions/round.rs b/src/promql/src/functions/round.rs new file mode 100644 index 0000000000..11779db22e --- /dev/null +++ b/src/promql/src/functions/round.rs @@ -0,0 +1,105 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use datafusion::error::DataFusionError; +use datafusion_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility}; +use datatypes::arrow::array::AsArray; +use datatypes::arrow::datatypes::{DataType, Float64Type}; +use datatypes::compute; + +use crate::functions::extract_array; + +pub struct Round { + nearest: f64, +} + +impl Round { + fn new(nearest: f64) -> Self { + Self { nearest } + } + + pub const fn name() -> &'static str { + "prom_round" + } + + fn input_type() -> Vec { + vec![DataType::Float64] + } + + pub fn return_type() -> DataType { + DataType::Float64 + } + + pub fn scalar_udf(nearest: f64) -> ScalarUDF { + create_udf( + Self::name(), + Self::input_type(), + Self::return_type(), + Volatility::Immutable, + Arc::new(move |input: &_| Self::new(nearest).calc(input)) as _, + ) + } + + fn calc(&self, input: &[ColumnarValue]) -> Result { + assert_eq!(input.len(), 1); + + let value_array = extract_array(&input[0])?; + + if self.nearest == 0.0 { + let values = value_array.as_primitive::(); + let result = compute::unary::<_, _, Float64Type>(values, |a| a.round()); + Ok(ColumnarValue::Array(Arc::new(result) as _)) + } else { + let values = value_array.as_primitive::(); + let nearest = self.nearest; + let result = + compute::unary::<_, _, Float64Type>(values, |a| ((a / nearest).round() * nearest)); + Ok(ColumnarValue::Array(Arc::new(result) as _)) + } + } +} + +#[cfg(test)] +mod tests { + use datatypes::arrow::array::Float64Array; + + use super::*; + + fn test_round_f64(value: Vec, nearest: f64, expected: Vec) { + let round_udf = Round::scalar_udf(nearest); + let input = vec![ColumnarValue::Array(Arc::new(Float64Array::from(value)))]; + let result = round_udf.invoke_batch(&input, 1).unwrap(); + let result_array = extract_array(&result).unwrap(); + assert_eq!(result_array.len(), 1); + assert_eq!( + result_array.as_primitive::().values(), + &expected + ); + } + + #[test] + fn test_round() { + test_round_f64(vec![123.456], 0.001, vec![123.456]); + test_round_f64(vec![123.456], 0.01, vec![123.46000000000001]); + test_round_f64(vec![123.456], 0.1, vec![123.5]); + test_round_f64(vec![123.456], 0.0, vec![123.0]); + test_round_f64(vec![123.456], 1.0, vec![123.0]); + test_round_f64(vec![123.456], 10.0, vec![120.0]); + test_round_f64(vec![123.456], 100.0, vec![100.0]); + test_round_f64(vec![123.456], 105.0, vec![105.0]); + test_round_f64(vec![123.456], 1000.0, vec![0.0]); + } +} diff --git a/src/query/src/promql/planner.rs b/src/query/src/promql/planner.rs index b1cb51e829..7b6d90374d 100644 --- a/src/query/src/promql/planner.rs +++ b/src/query/src/promql/planner.rs @@ -52,7 +52,7 @@ use promql::extension_plan::{ use promql::functions::{ AbsentOverTime, AvgOverTime, Changes, CountOverTime, Delta, Deriv, HoltWinters, IDelta, Increase, LastOverTime, MaxOverTime, MinOverTime, PredictLinear, PresentOverTime, - QuantileOverTime, Rate, Resets, StddevOverTime, StdvarOverTime, SumOverTime, + QuantileOverTime, Rate, Resets, Round, StddevOverTime, StdvarOverTime, SumOverTime, }; use promql_parser::label::{MatchOp, Matcher, Matchers, METRIC_NAME}; use promql_parser::parser::token::TokenType; @@ -1509,6 +1509,20 @@ impl PromPlanner { ScalarFunc::GeneratedExpr } + "round" => { + let nearest = match other_input_exprs.pop_front() { + Some(DfExpr::Literal(ScalarValue::Float64(Some(t)))) => t, + Some(DfExpr::Literal(ScalarValue::Int64(Some(t)))) => t as f64, + None => 0.0, + other => UnexpectedPlanExprSnafu { + desc: format!("expected f64 literal as t, but found {:?}", other), + } + .fail()?, + }; + + ScalarFunc::DataFusionUdf(Arc::new(Round::scalar_udf(nearest))) + } + _ => { if let Some(f) = session_state.scalar_functions().get(func.name) { ScalarFunc::DataFusionBuiltin(f.clone()) diff --git a/tests/cases/standalone/common/promql/round_fn.result b/tests/cases/standalone/common/promql/round_fn.result new file mode 100644 index 0000000000..fe12ca6f67 --- /dev/null +++ b/tests/cases/standalone/common/promql/round_fn.result @@ -0,0 +1,81 @@ +create table cache_hit ( + ts timestamp time index, + job string, + greptime_value double, + primary key (job) +); + +Affected Rows: 0 + +insert into cache_hit values + (3000, "read", 123.45), + (3000, "write", 234.567), + (4000, "read", 345.678), + (4000, "write", 456.789); + +Affected Rows: 4 + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 0.01); + ++---------------------+----------------------------+-------+ +| ts | prom_round(greptime_value) | job | ++---------------------+----------------------------+-------+ +| 1970-01-01T00:00:03 | 123.45 | read | +| 1970-01-01T00:00:03 | 234.57 | write | +| 1970-01-01T00:00:04 | 345.68 | read | +| 1970-01-01T00:00:04 | 456.79 | write | ++---------------------+----------------------------+-------+ + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 0.1); + ++---------------------+----------------------------+-------+ +| ts | prom_round(greptime_value) | job | ++---------------------+----------------------------+-------+ +| 1970-01-01T00:00:03 | 123.5 | read | +| 1970-01-01T00:00:03 | 234.60000000000002 | write | +| 1970-01-01T00:00:04 | 345.70000000000005 | read | +| 1970-01-01T00:00:04 | 456.8 | write | ++---------------------+----------------------------+-------+ + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 1.0); + ++---------------------+----------------------------+-------+ +| ts | prom_round(greptime_value) | job | ++---------------------+----------------------------+-------+ +| 1970-01-01T00:00:03 | 123.0 | read | +| 1970-01-01T00:00:03 | 235.0 | write | +| 1970-01-01T00:00:04 | 346.0 | read | +| 1970-01-01T00:00:04 | 457.0 | write | ++---------------------+----------------------------+-------+ + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit); + ++---------------------+----------------------------+-------+ +| ts | prom_round(greptime_value) | job | ++---------------------+----------------------------+-------+ +| 1970-01-01T00:00:03 | 123.0 | read | +| 1970-01-01T00:00:03 | 235.0 | write | +| 1970-01-01T00:00:04 | 346.0 | read | +| 1970-01-01T00:00:04 | 457.0 | write | ++---------------------+----------------------------+-------+ + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 10.0); + ++---------------------+----------------------------+-------+ +| ts | prom_round(greptime_value) | job | ++---------------------+----------------------------+-------+ +| 1970-01-01T00:00:03 | 120.0 | read | +| 1970-01-01T00:00:03 | 230.0 | write | +| 1970-01-01T00:00:04 | 350.0 | read | +| 1970-01-01T00:00:04 | 460.0 | write | ++---------------------+----------------------------+-------+ + +drop table cache_hit; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/common/promql/round_fn.sql b/tests/cases/standalone/common/promql/round_fn.sql new file mode 100644 index 0000000000..a623cc8adb --- /dev/null +++ b/tests/cases/standalone/common/promql/round_fn.sql @@ -0,0 +1,30 @@ + +create table cache_hit ( + ts timestamp time index, + job string, + greptime_value double, + primary key (job) +); + +insert into cache_hit values + (3000, "read", 123.45), + (3000, "write", 234.567), + (4000, "read", 345.678), + (4000, "write", 456.789); + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 0.01); + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 0.1); + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 1.0); + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit); + +-- SQLNESS SORT_RESULT 3 1 +tql eval (3, 4, '1s') round(cache_hit, 10.0); + +drop table cache_hit;