feat: implement clamp_min and clamp_max (#6116)

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
Ruihang Xia
2025-05-20 05:32:03 +08:00
committed by GitHub
parent a56e6e04c2
commit cd9b6990bf
4 changed files with 410 additions and 3 deletions

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod clamp;
pub mod clamp;
mod modulo;
mod pow;
mod rate;
@@ -20,7 +20,7 @@ mod rate;
use std::fmt;
use std::sync::Arc;
pub use clamp::ClampFunction;
pub use clamp::{ClampFunction, ClampMaxFunction, ClampMinFunction};
use common_query::error::{GeneralDataFusionSnafu, Result};
use common_query::prelude::Signature;
use datafusion::error::DataFusionError;
@@ -44,6 +44,8 @@ impl MathFunction {
registry.register(Arc::new(RateFunction));
registry.register(Arc::new(RangeFunction));
registry.register(Arc::new(ClampFunction));
registry.register(Arc::new(ClampMinFunction));
registry.register(Arc::new(ClampMaxFunction));
}
}

View File

@@ -155,6 +155,182 @@ fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: b
Ok(Arc::new(PrimitiveVector::<T>::from(result)))
}
/// Scalar function `clamp_min(input, min)`.
///
/// Takes a numeric input column and a scalar lower bound of the same type;
/// see the `Function` implementation for the argument validation rules.
#[derive(Clone, Debug, Default)]
pub struct ClampMinFunction;
/// Name returned by `ClampMinFunction::name` and (upper-cased) by its `Display` impl.
const CLAMP_MIN_NAME: &str = "clamp_min";
impl Function for ClampMinFunction {
    /// Name of the function, i.e. `clamp_min`.
    fn name(&self) -> &str {
        CLAMP_MIN_NAME
    }

    /// The result has the same type as the first (input) argument.
    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
        let input_type = &input_types[0];
        Ok(input_type.clone())
    }

    /// Two arguments `(input, min)`, both drawn from the numeric types.
    fn signature(&self) -> Signature {
        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
    }

    /// Evaluates `clamp_min` over `columns = [input, min]`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-function-args error when:
    /// - the number of columns is not exactly 2,
    /// - the input column is not numeric,
    /// - the two columns have different data types,
    /// - the `min` column does not have length 1,
    /// - the `min` value cannot be read as a primitive of the input type (e.g. it is null).
    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
        ensure!(
            columns.len() == 2,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The length of the args is not correct, expect exactly 2, have: {}",
                    columns.len()
                ),
            }
        );
        // Arity is verified above, so indexing cannot go out of bounds.
        let (value_col, bound_col) = (&columns[0], &columns[1]);

        let value_type = value_col.data_type();
        ensure!(
            value_type.is_numeric(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The first arg's type is not numeric, have: {}",
                    value_type
                ),
            }
        );

        let bound_type = bound_col.data_type();
        ensure!(
            value_type == bound_type,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "Arguments don't have identical types: {}, {}",
                    value_type, bound_type
                ),
            }
        );

        ensure!(
            bound_col.len() == 1,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The second arg (min) should be scalar, have: {:?}",
                    bound_col
                ),
            }
        );

        // Dispatch on the concrete primitive type of the input column.
        with_match_primitive_type_id!(value_type.logical_type_id(), |$S| {
            let arrow_values = value_col.to_arrow_array();
            let typed_values = arrow_values
                .as_any()
                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
                .unwrap();
            let lower_bound = TryAsPrimitive::<$S>::try_as_primitive(&bound_col.get(0))
                .with_context(|| InvalidFuncArgsSnafu {
                    err_msg: "The second arg (min) should not be none",
                })?;
            // Only the lower bound matters here (CLAMP_MAX = false), so the
            // upper bound passed to `clamp_impl` is just a placeholder value.
            let placeholder_max = <$S as LogicalPrimitiveType>::Native::default();
            clamp_impl::<$S, true, false>(typed_values, lower_bound, placeholder_max)
        }, {
            // Non-primitive types were rejected by the numeric check above.
            unreachable!()
        })
    }
}
impl Display for ClampMinFunction {
    /// Writes the function name in upper case (`CLAMP_MIN`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let upper_name = CLAMP_MIN_NAME.to_ascii_uppercase();
        f.write_str(&upper_name)
    }
}
/// Scalar function `clamp_max(input, max)`.
///
/// Takes a numeric input column and a scalar upper bound of the same type;
/// see the `Function` implementation for the argument validation rules.
#[derive(Clone, Debug, Default)]
pub struct ClampMaxFunction;
/// Name returned by `ClampMaxFunction::name` and (upper-cased) by its `Display` impl.
const CLAMP_MAX_NAME: &str = "clamp_max";
impl Function for ClampMaxFunction {
    /// Name of the function, i.e. `clamp_max`.
    fn name(&self) -> &str {
        CLAMP_MAX_NAME
    }

    /// The result has the same type as the first (input) argument.
    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
        let input_type = &input_types[0];
        Ok(input_type.clone())
    }

    /// Two arguments `(input, max)`, both drawn from the numeric types.
    fn signature(&self) -> Signature {
        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
    }

    /// Evaluates `clamp_max` over `columns = [input, max]`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-function-args error when:
    /// - the number of columns is not exactly 2,
    /// - the input column is not numeric,
    /// - the two columns have different data types,
    /// - the `max` column does not have length 1,
    /// - the `max` value cannot be read as a primitive of the input type (e.g. it is null).
    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
        ensure!(
            columns.len() == 2,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The length of the args is not correct, expect exactly 2, have: {}",
                    columns.len()
                ),
            }
        );
        // Arity is verified above, so indexing cannot go out of bounds.
        let (value_col, bound_col) = (&columns[0], &columns[1]);

        let value_type = value_col.data_type();
        ensure!(
            value_type.is_numeric(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The first arg's type is not numeric, have: {}",
                    value_type
                ),
            }
        );

        let bound_type = bound_col.data_type();
        ensure!(
            value_type == bound_type,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "Arguments don't have identical types: {}, {}",
                    value_type, bound_type
                ),
            }
        );

        ensure!(
            bound_col.len() == 1,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The second arg (max) should be scalar, have: {:?}",
                    bound_col
                ),
            }
        );

        // Dispatch on the concrete primitive type of the input column.
        with_match_primitive_type_id!(value_type.logical_type_id(), |$S| {
            let arrow_values = value_col.to_arrow_array();
            let typed_values = arrow_values
                .as_any()
                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
                .unwrap();
            let upper_bound = TryAsPrimitive::<$S>::try_as_primitive(&bound_col.get(0))
                .with_context(|| InvalidFuncArgsSnafu {
                    err_msg: "The second arg (max) should not be none",
                })?;
            // Only the upper bound matters here (CLAMP_MIN = false), so the
            // lower bound passed to `clamp_impl` is just a placeholder value.
            let placeholder_min = <$S as LogicalPrimitiveType>::Native::default();
            clamp_impl::<$S, false, true>(typed_values, placeholder_min, upper_bound)
        }, {
            // Non-primitive types were rejected by the numeric check above.
            unreachable!()
        })
    }
}
impl Display for ClampMaxFunction {
    /// Writes the function name in upper case (`CLAMP_MAX`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let upper_name = CLAMP_MAX_NAME.to_ascii_uppercase();
        f.write_str(&upper_name)
    }
}
#[cfg(test)]
mod test {
@@ -394,4 +570,134 @@ mod test {
let result = func.eval(&FunctionContext::default(), args.as_slice());
assert!(result.is_err());
}
#[test]
fn clamp_min_i64() {
    // Each case is (input values, lower bound, expected output).
    let cases = [
        (
            vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
            -1,
            vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
        ),
        // Null entries are expected to stay null.
        (
            vec![Some(-3), None, Some(-1), None, None, Some(2)],
            -2,
            vec![Some(-2), None, Some(-1), None, None, Some(2)],
        ),
    ];

    let func = ClampMinFunction;
    for (values, lower_bound, want) in cases {
        let columns: [VectorRef; 2] = [
            Arc::new(Int64Vector::from(values)),
            Arc::new(Int64Vector::from_vec(vec![lower_bound])),
        ];
        let got = func.eval(&FunctionContext::default(), &columns).unwrap();
        let want: VectorRef = Arc::new(Int64Vector::from(want));
        assert_eq!(want, got);
    }
}
#[test]
fn clamp_max_i64() {
    // Each case is (input values, upper bound, expected output).
    let cases = [
        (
            vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
            1,
            vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
        ),
        // Null entries are expected to stay null.
        (
            vec![Some(-3), None, Some(-1), None, None, Some(2)],
            0,
            vec![Some(-3), None, Some(-1), None, None, Some(0)],
        ),
    ];

    let func = ClampMaxFunction;
    for (values, upper_bound, want) in cases {
        let columns: [VectorRef; 2] = [
            Arc::new(Int64Vector::from(values)),
            Arc::new(Int64Vector::from_vec(vec![upper_bound])),
        ];
        let got = func.eval(&FunctionContext::default(), &columns).unwrap();
        let want: VectorRef = Arc::new(Int64Vector::from(want));
        assert_eq!(want, got);
    }
}
#[test]
fn clamp_min_f64() {
    // Single scenario: everything below -1.0 is raised to -1.0.
    let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
    let lower_bound = -1.0;
    let want = vec![Some(-1.0), Some(-1.0), Some(-1.0), Some(0.0), Some(1.0)];

    let func = ClampMinFunction;
    let columns: [VectorRef; 2] = [
        Arc::new(Float64Vector::from(values)),
        Arc::new(Float64Vector::from_vec(vec![lower_bound])),
    ];
    let got = func.eval(&FunctionContext::default(), &columns).unwrap();
    let want: VectorRef = Arc::new(Float64Vector::from(want));
    assert_eq!(want, got);
}
#[test]
fn clamp_max_f64() {
    // Single scenario: everything above 0.0 is lowered to 0.0.
    let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
    let upper_bound = 0.0;
    let want = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(0.0)];

    let func = ClampMaxFunction;
    let columns: [VectorRef; 2] = [
        Arc::new(Float64Vector::from(values)),
        Arc::new(Float64Vector::from_vec(vec![upper_bound])),
    ];
    let got = func.eval(&FunctionContext::default(), &columns).unwrap();
    let want: VectorRef = Arc::new(Float64Vector::from(want));
    assert_eq!(want, got);
}
#[test]
fn clamp_min_type_not_match() {
    // A Float64 input paired with an Int64 bound must be rejected.
    let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
    let lower_bound = -1;

    let columns: [VectorRef; 2] = [
        Arc::new(Float64Vector::from(values)),
        Arc::new(Int64Vector::from_vec(vec![lower_bound])),
    ];
    let outcome = ClampMinFunction.eval(&FunctionContext::default(), &columns);
    assert!(outcome.is_err());
}
#[test]
fn clamp_max_type_not_match() {
    // A Float64 input paired with an Int64 bound must be rejected.
    let values = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
    let upper_bound = 1;

    let columns: [VectorRef; 2] = [
        Arc::new(Float64Vector::from(values)),
        Arc::new(Int64Vector::from_vec(vec![upper_bound])),
    ];
    let outcome = ClampMaxFunction.eval(&FunctionContext::default(), &columns);
    assert!(outcome.is_err());
}
}

View File

@@ -78,3 +78,83 @@ SELECT CLAMP(10, 1, 0);
Error: 3001(EngineExecuteQuery), Invalid function args: The second arg should be less than or equal to the third arg, have: ConstantVector([Int64(1); 1]), ConstantVector([Int64(0); 1])
SELECT CLAMP_MIN(10, 12);
+--------------------------------+
| clamp_min(Int64(10),Int64(12)) |
+--------------------------------+
| 12 |
+--------------------------------+
SELECT CLAMP_MIN(10, 8);
+-------------------------------+
| clamp_min(Int64(10),Int64(8)) |
+-------------------------------+
| 10 |
+-------------------------------+
SELECT CLAMP_MIN(10.5, 10.6);
+----------------------------------------+
| clamp_min(Float64(10.5),Float64(10.6)) |
+----------------------------------------+
| 10.6 |
+----------------------------------------+
SELECT CLAMP_MIN(10.5, 10.4);
+----------------------------------------+
| clamp_min(Float64(10.5),Float64(10.4)) |
+----------------------------------------+
| 10.5 |
+----------------------------------------+
SELECT CLAMP_MIN(-5, -3);
+--------------------------------+
| clamp_min(Int64(-5),Int64(-3)) |
+--------------------------------+
| -3 |
+--------------------------------+
SELECT CLAMP_MAX(10, 12);
+--------------------------------+
| clamp_max(Int64(10),Int64(12)) |
+--------------------------------+
| 10 |
+--------------------------------+
SELECT CLAMP_MAX(10, 8);
+-------------------------------+
| clamp_max(Int64(10),Int64(8)) |
+-------------------------------+
| 8 |
+-------------------------------+
SELECT CLAMP_MAX(10.5, 10.6);
+----------------------------------------+
| clamp_max(Float64(10.5),Float64(10.6)) |
+----------------------------------------+
| 10.5 |
+----------------------------------------+
SELECT CLAMP_MAX(10.5, 10.4);
+----------------------------------------+
| clamp_max(Float64(10.5),Float64(10.4)) |
+----------------------------------------+
| 10.4 |
+----------------------------------------+
SELECT CLAMP_MAX(-5, -7);
+--------------------------------+
| clamp_max(Int64(-5),Int64(-7)) |
+--------------------------------+
| -7 |
+--------------------------------+

View File

@@ -1,4 +1,3 @@
SELECT MOD(18, 4);
SELECT MOD(-18, 4);
@@ -23,3 +22,23 @@ SELECT CLAMP(-10, 0, 1);
SELECT CLAMP(0.5, 0, 1);
SELECT CLAMP(10, 1, 0);
SELECT CLAMP_MIN(10, 12);
SELECT CLAMP_MIN(10, 8);
SELECT CLAMP_MIN(10.5, 10.6);
SELECT CLAMP_MIN(10.5, 10.4);
SELECT CLAMP_MIN(-5, -3);
SELECT CLAMP_MAX(10, 12);
SELECT CLAMP_MAX(10, 8);
SELECT CLAMP_MAX(10.5, 10.6);
SELECT CLAMP_MAX(10.5, 10.4);
SELECT CLAMP_MAX(-5, -7);