diff --git a/src/common/function/src/scalars/math.rs b/src/common/function/src/scalars/math.rs index 660bbffda3..f7d50f881d 100644 --- a/src/common/function/src/scalars/math.rs +++ b/src/common/function/src/scalars/math.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod clamp; mod modulo; mod pow; mod rate; @@ -19,6 +20,7 @@ mod rate; use std::fmt; use std::sync::Arc; +pub use clamp::ClampFunction; use common_query::error::{GeneralDataFusionSnafu, Result}; use common_query::prelude::Signature; use datafusion::error::DataFusionError; @@ -40,7 +42,8 @@ impl MathFunction { registry.register(Arc::new(ModuloFunction)); registry.register(Arc::new(PowFunction)); registry.register(Arc::new(RateFunction)); - registry.register(Arc::new(RangeFunction)) + registry.register(Arc::new(RangeFunction)); + registry.register(Arc::new(ClampFunction)); } } diff --git a/src/common/function/src/scalars/math/clamp.rs b/src/common/function/src/scalars/math/clamp.rs new file mode 100644 index 0000000000..58a2dcefd4 --- /dev/null +++ b/src/common/function/src/scalars/math/clamp.rs @@ -0,0 +1,403 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::{self, Display}; +use std::sync::Arc; + +use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::prelude::Signature; +use datafusion::arrow::array::{ArrayIter, PrimitiveArray}; +use datafusion::logical_expr::Volatility; +use datatypes::data_type::{ConcreteDataType, DataType}; +use datatypes::prelude::VectorRef; +use datatypes::types::LogicalPrimitiveType; +use datatypes::value::TryAsPrimitive; +use datatypes::vectors::PrimitiveVector; +use datatypes::with_match_primitive_type_id; +use snafu::{ensure, OptionExt}; + +use crate::function::Function; + +#[derive(Clone, Debug, Default)] +pub struct ClampFunction; + +const CLAMP_NAME: &str = "clamp"; + +impl Function for ClampFunction { + fn name(&self) -> &str { + CLAMP_NAME + } + + fn return_type(&self, input_types: &[ConcreteDataType]) -> Result { + // Type check is done by `signature` + Ok(input_types[0].clone()) + } + + fn signature(&self) -> Signature { + // input, min, max + Signature::uniform(3, ConcreteDataType::numerics(), Volatility::Immutable) + } + + fn eval( + &self, + _func_ctx: crate::function::FunctionContext, + columns: &[VectorRef], + ) -> Result { + ensure!( + columns.len() == 3, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect exactly 3, have: {}", + columns.len() + ), + } + ); + ensure!( + columns[0].data_type().is_numeric(), + InvalidFuncArgsSnafu { + err_msg: format!( + "The first arg's type is not numeric, have: {}", + columns[0].data_type() + ), + } + ); + ensure!( + columns[0].data_type() == columns[1].data_type() + && columns[1].data_type() == columns[2].data_type(), + InvalidFuncArgsSnafu { + err_msg: format!( + "Arguments don't have identical types: {}, {}, {}", + columns[0].data_type(), + columns[1].data_type(), + columns[2].data_type() + ), + } + ); + ensure!( + columns[1].len() == 1 && columns[2].len() == 1, + InvalidFuncArgsSnafu { + err_msg: format!( + "The second and third args should be scalar, have: {:?}, {:?}", + columns[1], columns[2] + ), + } + ); + + with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { + let input_array = columns[0].to_arrow_array(); + let input = input_array + .as_any() + .downcast_ref::::ArrowPrimitive>>() + .unwrap(); + + let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0)) + .with_context(|| { + InvalidFuncArgsSnafu { + err_msg: "The second arg should not be none", + } + })?; + let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0)) + .with_context(|| { + InvalidFuncArgsSnafu { + err_msg: "The third arg should not be none", + } + })?; + + // ensure min <= max + ensure!( + min <= max, + InvalidFuncArgsSnafu { + err_msg: format!( + "The second arg should be less than or equal to the third arg, have: {:?}, {:?}", + columns[1], columns[2] + ), + } + ); + + clamp_impl::<$S, true, true>(input, min, max) + },{ + unreachable!() + }) + } +} + +impl Display for ClampFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", CLAMP_NAME.to_ascii_uppercase()) + } +} + +fn clamp_impl( + input: &PrimitiveArray, + min: T::Native, + max: T::Native, +) -> Result { + common_telemetry::info!("[DEBUG] min {min:?}, max {max:?}"); + + let iter = ArrayIter::new(input); + let result = iter.map(|x| { + x.map(|x| { + if CLAMP_MIN && x < min { + min + } else if CLAMP_MAX && x > max { + max + } else { + x + } + }) + }); + let result = PrimitiveArray::::from_iter(result); + Ok(Arc::new(PrimitiveVector::::from(result))) +} + +#[cfg(test)] +mod test { + + use std::sync::Arc; + + use datatypes::prelude::ScalarVector; + use datatypes::vectors::{ + ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector, + }; + + use super::*; + use crate::function::FunctionContext; + + #[test] + fn clamp_i64() { + let inputs = [ + ( + vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], + -1, + 10, + vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)], + ), + ( + vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], + 0, + 0, + vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)], + ), + ( + vec![Some(-3), None, Some(-1), None, None, Some(2)], + -2, + 1, + vec![Some(-2), None, Some(-1), None, None, Some(1)], + ), + ( + vec![None, None, None, None, None], + 0, + 1, + vec![None, None, None, None, None], + ), + ]; + + let func = ClampFunction; + for (in_data, min, max, expected) in inputs { + let args = [ + Arc::new(Int64Vector::from(in_data)) as _, + Arc::new(Int64Vector::from_vec(vec![min])) as _, + Arc::new(Int64Vector::from_vec(vec![max])) as _, + ]; + let result = func + .eval(FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(Int64Vector::from(expected)); + assert_eq!(expected, result); + } + } + + #[test] + fn clamp_u64() { + let inputs = [ + ( + vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)], + 1, + 3, + vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)], + ), + ( + vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)], + 0, + 0, + vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)], + ), + ( + vec![Some(0), None, Some(2), None, None, Some(5)], + 1, + 3, + vec![Some(1), None, Some(2), None, None, Some(3)], + ), + ( + vec![None, None, None, None, None], + 0, + 1, + vec![None, None, None, None, None], + ), + ]; + + let func = ClampFunction; + for (in_data, min, max, expected) in inputs { + let args = [ + Arc::new(UInt64Vector::from(in_data)) as _, + Arc::new(UInt64Vector::from_vec(vec![min])) as _, + Arc::new(UInt64Vector::from_vec(vec![max])) as _, + ]; + let result = func + .eval(FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(UInt64Vector::from(expected)); + assert_eq!(expected, result); + } + } + + #[test] + fn clamp_f64() { + let inputs = [ + ( + vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)], + -1.0, + 10.0, + vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)], + ), + ( + vec![Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)], + 0.0, + 0.0, + vec![Some(0.0), Some(0.0), Some(0.0), Some(0.0)], + ), + ( + vec![Some(-3.0), None, Some(-1.0), None, None, Some(2.0)], + -2.0, + 1.0, + vec![Some(-2.0), None, Some(-1.0), None, None, Some(1.0)], + ), + ( + vec![None, None, None, None, None], + 0.0, + 1.0, + vec![None, None, None, None, None], + ), + ]; + + let func = ClampFunction; + for (in_data, min, max, expected) in inputs { + let args = [ + Arc::new(Float64Vector::from(in_data)) as _, + Arc::new(Float64Vector::from_vec(vec![min])) as _, + Arc::new(Float64Vector::from_vec(vec![max])) as _, + ]; + let result = func + .eval(FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(Float64Vector::from(expected)); + assert_eq!(expected, result); + } + } + + #[test] + fn clamp_const_i32() { + let input = vec![Some(5)]; + let min = 2; + let max = 4; + + let func = ClampFunction; + let args = [ + Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _, + Arc::new(Int64Vector::from_vec(vec![min])) as _, + Arc::new(Int64Vector::from_vec(vec![max])) as _, + ]; + let result = func + .eval(FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)])); + assert_eq!(expected, result); + } + + #[test] + fn clamp_invalid_min_max() { + let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; + let min = 10.0; + let max = -1.0; + + let func = ClampFunction; + let args = [ + Arc::new(Float64Vector::from(input)) as _, + Arc::new(Float64Vector::from_vec(vec![min])) as _, + Arc::new(Float64Vector::from_vec(vec![max])) as _, + ]; + let result = func.eval(FunctionContext::default(), args.as_slice()); + assert!(result.is_err()); + } + + #[test] + fn clamp_type_not_match() { + let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; + let min = -1; + let max = 10; + + let func = ClampFunction; + let args = [ + Arc::new(Float64Vector::from(input)) as _, + Arc::new(Int64Vector::from_vec(vec![min])) as _, + Arc::new(UInt64Vector::from_vec(vec![max])) as _, + ]; + let result = func.eval(FunctionContext::default(), args.as_slice()); + assert!(result.is_err()); + } + + #[test] + fn clamp_min_is_not_scalar() { + let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; + let min = -10.0; + let max = 1.0; + + let func = ClampFunction; + let args = [ + Arc::new(Float64Vector::from(input)) as _, + Arc::new(Float64Vector::from_vec(vec![min, min])) as _, + Arc::new(Float64Vector::from_vec(vec![max])) as _, + ]; + let result = func.eval(FunctionContext::default(), args.as_slice()); + assert!(result.is_err()); + } + + #[test] + fn clamp_no_max() { + let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; + let min = -10.0; + + let func = ClampFunction; + let args = [ + Arc::new(Float64Vector::from(input)) as _, + Arc::new(Float64Vector::from_vec(vec![min])) as _, + ]; + let result = func.eval(FunctionContext::default(), args.as_slice()); + assert!(result.is_err()); + } + + #[test] + fn clamp_on_string() { + let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")]; + + let func = ClampFunction; + let args = [ + Arc::new(StringVector::from(input)) as _, + Arc::new(StringVector::from_vec(vec!["bar"])) as _, + Arc::new(StringVector::from_vec(vec!["baz"])) as _, + ]; + let result = func.eval(FunctionContext::default(), args.as_slice()); + assert!(result.is_err()); + } +} diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index 031f355655..bfd4a11103 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -370,6 +370,36 @@ impl Value { } } +pub trait TryAsPrimitive { + fn try_as_primitive(&self) -> Option; +} + +macro_rules! impl_try_as_primitive { + ($Type: ident, $Variant: ident) => { + impl TryAsPrimitive for Value { + fn try_as_primitive( + &self, + ) -> Option<::Native> { + match self { + Value::$Variant(v) => Some((*v).into()), + _ => None, + } + } + } + }; +} + +impl_try_as_primitive!(Int8Type, Int8); +impl_try_as_primitive!(Int16Type, Int16); +impl_try_as_primitive!(Int32Type, Int32); +impl_try_as_primitive!(Int64Type, Int64); +impl_try_as_primitive!(UInt8Type, UInt8); +impl_try_as_primitive!(UInt16Type, UInt16); +impl_try_as_primitive!(UInt32Type, UInt32); +impl_try_as_primitive!(UInt64Type, UInt64); +impl_try_as_primitive!(Float32Type, Float32); +impl_try_as_primitive!(Float64Type, Float64); + pub fn to_null_scalar_value(output_type: &ConcreteDataType) -> Result { Ok(match output_type { ConcreteDataType::Null(_) => ScalarValue::Null, diff --git a/tests/cases/standalone/common/function/arithmetic.result b/tests/cases/standalone/common/function/arithmetic.result index caa8f1e397..563053fbce 100644 --- a/tests/cases/standalone/common/function/arithmetic.result +++ b/tests/cases/standalone/common/function/arithmetic.result @@ -50,3 +50,31 @@ SELECT POW (0.99, 365); | 0.025517964452291125 | +-------------------------------+ +SELECT CLAMP(10, 0, 1); + ++------------------------------------+ +| clamp(Int64(10),Int64(0),Int64(1)) | ++------------------------------------+ +| 1 | ++------------------------------------+ + +SELECT CLAMP(-10, 0, 1); + ++-------------------------------------+ +| clamp(Int64(-10),Int64(0),Int64(1)) | ++-------------------------------------+ +| 0 | ++-------------------------------------+ + +SELECT CLAMP(0.5, 0, 1); + ++---------------------------------------+ +| clamp(Float64(0.5),Int64(0),Int64(1)) | ++---------------------------------------+ +| 0.5 | ++---------------------------------------+ + +SELECT CLAMP(10, 1, 0); + +Error: 3001(EngineExecuteQuery), DataFusion error: Invalid function args: The second arg should be less than or equal to the third arg, have: ConstantVector([Int64(1); 1]), ConstantVector([Int64(0); 1]) + diff --git a/tests/cases/standalone/common/function/arithmetic.sql b/tests/cases/standalone/common/function/arithmetic.sql index fd048f00a0..da391d3d16 100644 --- a/tests/cases/standalone/common/function/arithmetic.sql +++ b/tests/cases/standalone/common/function/arithmetic.sql @@ -14,3 +14,12 @@ SELECT POW (2, 5); SELECT POW (1.01, 365); SELECT POW (0.99, 365); + + +SELECT CLAMP(10, 0, 1); + +SELECT CLAMP(-10, 0, 1); + +SELECT CLAMP(0.5, 0, 1); + +SELECT CLAMP(10, 1, 0);