diff --git a/src/common/function/src/scalars/math.rs b/src/common/function/src/scalars/math.rs index 152ba999f3..fd37a9fd6e 100644 --- a/src/common/function/src/scalars/math.rs +++ b/src/common/function/src/scalars/math.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod clamp; +pub mod clamp; mod modulo; mod pow; mod rate; @@ -20,7 +20,7 @@ mod rate; use std::fmt; use std::sync::Arc; -pub use clamp::ClampFunction; +pub use clamp::{ClampFunction, ClampMaxFunction, ClampMinFunction}; use common_query::error::{GeneralDataFusionSnafu, Result}; use common_query::prelude::Signature; use datafusion::error::DataFusionError; @@ -44,6 +44,8 @@ impl MathFunction { registry.register(Arc::new(RateFunction)); registry.register(Arc::new(RangeFunction)); registry.register(Arc::new(ClampFunction)); + registry.register(Arc::new(ClampMinFunction)); + registry.register(Arc::new(ClampMaxFunction)); } } diff --git a/src/common/function/src/scalars/math/clamp.rs b/src/common/function/src/scalars/math/clamp.rs index 6c19da8212..87d1b0ff1e 100644 --- a/src/common/function/src/scalars/math/clamp.rs +++ b/src/common/function/src/scalars/math/clamp.rs @@ -155,6 +155,182 @@ fn clamp_impl::from(result))) } +#[derive(Clone, Debug, Default)] +pub struct ClampMinFunction; + +const CLAMP_MIN_NAME: &str = "clamp_min"; + +impl Function for ClampMinFunction { + fn name(&self) -> &str { + CLAMP_MIN_NAME + } + + fn return_type(&self, input_types: &[ConcreteDataType]) -> Result { + Ok(input_types[0].clone()) + } + + fn signature(&self) -> Signature { + // input, min + Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable) + } + + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { + ensure!( + columns.len() == 2, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect exactly 2, have: {}", + columns.len() + ), + } + ); + ensure!( + columns[0].data_type().is_numeric(), + InvalidFuncArgsSnafu { + err_msg: format!( + "The first arg's type is not numeric, have: {}", + columns[0].data_type() + ), + } + ); + ensure!( + columns[0].data_type() == columns[1].data_type(), + InvalidFuncArgsSnafu { + err_msg: format!( + "Arguments don't have identical types: {}, {}", + columns[0].data_type(), + columns[1].data_type() + ), + } + ); + ensure!( + columns[1].len() == 1, + InvalidFuncArgsSnafu { + err_msg: format!( + "The second arg (min) should be scalar, have: {:?}", + columns[1] + ), + } + ); + + with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { + let input_array = columns[0].to_arrow_array(); + let input = input_array + .as_any() + .downcast_ref::::ArrowPrimitive>>() + .unwrap(); + + let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0)) + .with_context(|| { + InvalidFuncArgsSnafu { + err_msg: "The second arg (min) should not be none", + } + })?; + // For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic. + // We pass a default/dummy value for max. + let max_dummy = <$S as LogicalPrimitiveType>::Native::default(); + + clamp_impl::<$S, true, false>(input, min, max_dummy) + },{ + unreachable!() + }) + } +} + +impl Display for ClampMinFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", CLAMP_MIN_NAME.to_ascii_uppercase()) + } +} + +#[derive(Clone, Debug, Default)] +pub struct ClampMaxFunction; + +const CLAMP_MAX_NAME: &str = "clamp_max"; + +impl Function for ClampMaxFunction { + fn name(&self) -> &str { + CLAMP_MAX_NAME + } + + fn return_type(&self, input_types: &[ConcreteDataType]) -> Result { + Ok(input_types[0].clone()) + } + + fn signature(&self) -> Signature { + // input, max + Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable) + } + + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { + ensure!( + columns.len() == 2, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect exactly 2, have: {}", + columns.len() + ), + } + ); + ensure!( + columns[0].data_type().is_numeric(), + InvalidFuncArgsSnafu { + err_msg: format!( + "The first arg's type is not numeric, have: {}", + columns[0].data_type() + ), + } + ); + ensure!( + columns[0].data_type() == columns[1].data_type(), + InvalidFuncArgsSnafu { + err_msg: format!( + "Arguments don't have identical types: {}, {}", + columns[0].data_type(), + columns[1].data_type() + ), + } + ); + ensure!( + columns[1].len() == 1, + InvalidFuncArgsSnafu { + err_msg: format!( + "The second arg (max) should be scalar, have: {:?}", + columns[1] + ), + } + ); + + with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { + let input_array = columns[0].to_arrow_array(); + let input = input_array + .as_any() + .downcast_ref::::ArrowPrimitive>>() + .unwrap(); + + let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0)) + .with_context(|| { + InvalidFuncArgsSnafu { + err_msg: "The second arg (max) should not be none", + } + })?; + // For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic. + // We pass a default/dummy value for min. + let min_dummy = <$S as LogicalPrimitiveType>::Native::default(); + + clamp_impl::<$S, false, true>(input, min_dummy, max) + },{ + unreachable!() + }) + } +} + +impl Display for ClampMaxFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", CLAMP_MAX_NAME.to_ascii_uppercase()) + } +} + #[cfg(test)] mod test { @@ -394,4 +570,134 @@ mod test { let result = func.eval(&FunctionContext::default(), args.as_slice()); assert!(result.is_err()); } + + #[test] + fn clamp_min_i64() { + let inputs = [ + ( + vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], + -1, + vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)], + ), + ( + vec![Some(-3), None, Some(-1), None, None, Some(2)], + -2, + vec![Some(-2), None, Some(-1), None, None, Some(2)], + ), + ]; + + let func = ClampMinFunction; + for (in_data, min, expected) in inputs { + let args = [ + Arc::new(Int64Vector::from(in_data)) as _, + Arc::new(Int64Vector::from_vec(vec![min])) as _, + ]; + let result = func + .eval(&FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(Int64Vector::from(expected)); + assert_eq!(expected, result); + } + } + + #[test] + fn clamp_max_i64() { + let inputs = [ + ( + vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)], + 1, + vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)], + ), + ( + vec![Some(-3), None, Some(-1), None, None, Some(2)], + 0, + vec![Some(-3), None, Some(-1), None, None, Some(0)], + ), + ]; + + let func = ClampMaxFunction; + for (in_data, max, expected) in inputs { + let args = [ + Arc::new(Int64Vector::from(in_data)) as _, + Arc::new(Int64Vector::from_vec(vec![max])) as _, + ]; + let result = func + .eval(&FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(Int64Vector::from(expected)); + assert_eq!(expected, result); + } + } + + #[test] + fn clamp_min_f64() { + let inputs = [( + vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)], + -1.0, + vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)], + )]; + + let func = ClampMinFunction; + for (in_data, min, expected) in inputs { + let args = [ + Arc::new(Float64Vector::from(in_data)) as _, + Arc::new(Float64Vector::from_vec(vec![min])) as _, + ]; + let result = func + .eval(&FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(Float64Vector::from(expected)); + assert_eq!(expected, result); + } + } + + #[test] + fn clamp_max_f64() { + let inputs = [( + vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)], + 0.0, + vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)], + )]; + + let func = ClampMaxFunction; + for (in_data, max, expected) in inputs { + let args = [ + Arc::new(Float64Vector::from(in_data)) as _, + Arc::new(Float64Vector::from_vec(vec![max])) as _, + ]; + let result = func + .eval(&FunctionContext::default(), args.as_slice()) + .unwrap(); + let expected: VectorRef = Arc::new(Float64Vector::from(expected)); + assert_eq!(expected, result); + } + } + + #[test] + fn clamp_min_type_not_match() { + let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; + let min = -1; + + let func = ClampMinFunction; + let args = [ + Arc::new(Float64Vector::from(input)) as _, + Arc::new(Int64Vector::from_vec(vec![min])) as _, + ]; + let result = func.eval(&FunctionContext::default(), args.as_slice()); + assert!(result.is_err()); + } + + #[test] + fn clamp_max_type_not_match() { + let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)]; + let max = 1; + + let func = ClampMaxFunction; + let args = [ + Arc::new(Float64Vector::from(input)) as _, + Arc::new(Int64Vector::from_vec(vec![max])) as _, + ]; + let result = func.eval(&FunctionContext::default(), args.as_slice()); + assert!(result.is_err()); + } } diff --git a/tests/cases/standalone/common/function/arithmetic.result b/tests/cases/standalone/common/function/arithmetic.result index 20c612d56a..4dd780c61b 100644 --- a/tests/cases/standalone/common/function/arithmetic.result +++ b/tests/cases/standalone/common/function/arithmetic.result @@ -78,3 +78,83 @@ SELECT CLAMP(10, 1, 0); Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: ConstantVector([Int64(1); 1]), ConstantVector([Int64(0); 1]) +SELECT CLAMP_MIN(10, 12); + ++--------------------------------+ +| clamp_min(Int64(10),Int64(12)) | ++--------------------------------+ +| 12 | ++--------------------------------+ + +SELECT CLAMP_MIN(10, 8); + ++-------------------------------+ +| clamp_min(Int64(10),Int64(8)) | ++-------------------------------+ +| 10 | ++-------------------------------+ + +SELECT CLAMP_MIN(10.5, 10.6); + ++----------------------------------------+ +| clamp_min(Float64(10.5),Float64(10.6)) | ++----------------------------------------+ +| 10.6 | ++----------------------------------------+ + +SELECT CLAMP_MIN(10.5, 10.4); + ++----------------------------------------+ +| clamp_min(Float64(10.5),Float64(10.4)) | ++----------------------------------------+ +| 10.5 | ++----------------------------------------+ + +SELECT CLAMP_MIN(-5, -3); + ++--------------------------------+ +| clamp_min(Int64(-5),Int64(-3)) | ++--------------------------------+ +| -3 | ++--------------------------------+ + +SELECT CLAMP_MAX(10, 12); + ++--------------------------------+ +| clamp_max(Int64(10),Int64(12)) | ++--------------------------------+ +| 10 | ++--------------------------------+ + +SELECT CLAMP_MAX(10, 8); + ++-------------------------------+ +| clamp_max(Int64(10),Int64(8)) | ++-------------------------------+ +| 8 | ++-------------------------------+ + +SELECT CLAMP_MAX(10.5, 10.6); + ++----------------------------------------+ +| clamp_max(Float64(10.5),Float64(10.6)) | ++----------------------------------------+ +| 10.5 | ++----------------------------------------+ + +SELECT CLAMP_MAX(10.5, 10.4); + ++----------------------------------------+ +| clamp_max(Float64(10.5),Float64(10.4)) | ++----------------------------------------+ +| 10.4 | ++----------------------------------------+ + +SELECT CLAMP_MAX(-5, -7); + ++--------------------------------+ +| clamp_max(Int64(-5),Int64(-7)) | ++--------------------------------+ +| -7 | ++--------------------------------+ + diff --git a/tests/cases/standalone/common/function/arithmetic.sql b/tests/cases/standalone/common/function/arithmetic.sql index da391d3d16..bda3111b35 100644 --- a/tests/cases/standalone/common/function/arithmetic.sql +++ b/tests/cases/standalone/common/function/arithmetic.sql @@ -1,4 +1,3 @@ - SELECT MOD(18, 4); SELECT MOD(-18, 4); @@ -23,3 +22,23 @@ SELECT CLAMP(-10, 0, 1); SELECT CLAMP(0.5, 0, 1); SELECT CLAMP(10, 1, 0); + +SELECT CLAMP_MIN(10, 12); + +SELECT CLAMP_MIN(10, 8); + +SELECT CLAMP_MIN(10.5, 10.6); + +SELECT CLAMP_MIN(10.5, 10.4); + +SELECT CLAMP_MIN(-5, -3); + +SELECT CLAMP_MAX(10, 12); + +SELECT CLAMP_MAX(10, 8); + +SELECT CLAMP_MAX(10.5, 10.6); + +SELECT CLAMP_MAX(10.5, 10.4); + +SELECT CLAMP_MAX(-5, -7);