From ca4d690424b03806ea0f8bd5e491585224bbf220 Mon Sep 17 00:00:00 2001
From: Eugene Tolbakov
Date: Sat, 13 Jan 2024 00:24:25 +0000
Subject: [PATCH] feat: add modulo function (#3147)

* feat: add modulo function

* fix: address CR feedback
---
 src/common/function/src/scalars/math.rs       |   3 +
 .../function/src/scalars/math/modulo.rs       | 262 ++++++++++++++++++
 .../common/function/arithmetic.result         |  52 ++++
 .../standalone/common/function/arithmetic.sql |  16 ++
 4 files changed, 333 insertions(+)
 create mode 100644 src/common/function/src/scalars/math/modulo.rs
 create mode 100644 tests/cases/standalone/common/function/arithmetic.result
 create mode 100644 tests/cases/standalone/common/function/arithmetic.sql

diff --git a/src/common/function/src/scalars/math.rs b/src/common/function/src/scalars/math.rs
index 1e9609c517..660bbffda3 100644
--- a/src/common/function/src/scalars/math.rs
+++ b/src/common/function/src/scalars/math.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+mod modulo;
 mod pow;
 mod rate;
 
@@ -30,11 +31,13 @@ use snafu::ResultExt;
 
 use crate::function::{Function, FunctionContext};
 use crate::function_registry::FunctionRegistry;
+use crate::scalars::math::modulo::ModuloFunction;
 
 pub(crate) struct MathFunction;
 
 impl MathFunction {
     pub fn register(registry: &FunctionRegistry) {
+        registry.register(Arc::new(ModuloFunction));
         registry.register(Arc::new(PowFunction));
         registry.register(Arc::new(RateFunction));
         registry.register(Arc::new(RangeFunction))
diff --git a/src/common/function/src/scalars/math/modulo.rs b/src/common/function/src/scalars/math/modulo.rs
new file mode 100644
index 0000000000..df2d84e66b
--- /dev/null
+++ b/src/common/function/src/scalars/math/modulo.rs
@@ -0,0 +1,262 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::fmt;
+use std::fmt::Display;
+
+use common_query::error;
+use common_query::error::{ArrowComputeSnafu, InvalidFuncArgsSnafu, Result};
+use common_query::prelude::{Signature, Volatility};
+use datatypes::arrow::compute;
+use datatypes::arrow::compute::kernels::numeric;
+use datatypes::arrow::datatypes::DataType as ArrowDataType;
+use datatypes::prelude::ConcreteDataType;
+use datatypes::vectors::{Helper, VectorRef};
+use snafu::{ensure, ResultExt};
+
+use crate::function::{Function, FunctionContext};
+
+/// Name under which the function is exposed.
+const NAME: &str = "mod";
+
+/// The function to find remainders.
+///
+/// Evaluates `columns[0] % columns[1]` row by row on two numeric columns.
+#[derive(Clone, Debug, Default)]
+pub struct ModuloFunction;
+
+impl Display for ModuloFunction {
+    /// Writes the function name in upper case.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", NAME.to_ascii_uppercase())
+    }
+}
+
+impl Function for ModuloFunction {
+    fn name(&self) -> &str {
+        NAME
+    }
+
+    /// Returns `Int64` when every input is signed, `UInt64` when every input is
+    /// unsigned, and `Float64` otherwise.
+    fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
+        if input_types.iter().all(ConcreteDataType::is_signed) {
+            Ok(ConcreteDataType::int64_datatype())
+        } else if input_types.iter().all(ConcreteDataType::is_unsigned) {
+            Ok(ConcreteDataType::uint64_datatype())
+        } else {
+            Ok(ConcreteDataType::float64_datatype())
+        }
+    }
+
+    /// Two arguments, each taken from `ConcreteDataType::numerics()`.
+    fn signature(&self) -> Signature {
+        Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
+    }
+
+    /// Computes `columns[0] % columns[1]` and casts the remainder to the 64-bit
+    /// type that matches the first column (`Int64`, `UInt64` or `Float64`).
+    ///
+    /// # Errors
+    ///
+    /// Fails when the argument count is not two, when the arrow remainder
+    /// kernel rejects the operands (for instance an integer division by zero),
+    /// or when the first column is neither an integer nor a float vector.
+    fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
+        ensure!(
+            columns.len() == 2,
+            InvalidFuncArgsSnafu {
+                err_msg: format!(
+                    "The length of the args is not correct, expect exactly two, have: {}",
+                    columns.len()
+                ),
+            }
+        );
+        let nums = &columns[0];
+        let divs = &columns[1];
+        let nums_arrow_array = &nums.to_arrow_array();
+        let divs_arrow_array = &divs.to_arrow_array();
+        let array = numeric::rem(nums_arrow_array, divs_arrow_array).context(ArrowComputeSnafu)?;
+
+        // Pick the 64-bit type of the first operand's family as the cast target.
+        let target_type = match nums.data_type() {
+            ConcreteDataType::Int8(_)
+            | ConcreteDataType::Int16(_)
+            | ConcreteDataType::Int32(_)
+            | ConcreteDataType::Int64(_) => ArrowDataType::Int64,
+            ConcreteDataType::UInt8(_)
+            | ConcreteDataType::UInt16(_)
+            | ConcreteDataType::UInt32(_)
+            | ConcreteDataType::UInt64(_) => ArrowDataType::UInt64,
+            ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => ArrowDataType::Float64,
+            // Report unsupported inputs as an error instead of panicking.
+            other => {
+                return InvalidFuncArgsSnafu {
+                    err_msg: format!("unexpected datatype: {:?}", other),
+                }
+                .fail();
+            }
+        };
+        let result = compute::cast(&array, &target_type).context(ArrowComputeSnafu)?;
+        Helper::try_into_vector(&result).context(error::FromArrowArraySnafu)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use common_error::ext::ErrorExt;
+    use datatypes::value::Value;
+    use datatypes::vectors::{Float64Vector, Int32Vector, StringVector, UInt32Vector};
+
+    use super::*;
+
+    #[test]
+    fn test_mod_function_signed() {
+        let function = ModuloFunction;
+        assert_eq!("mod", function.name());
+        assert_eq!(
+            ConcreteDataType::int64_datatype(),
+            function
+                .return_type(&[ConcreteDataType::int64_datatype()])
+                .unwrap()
+        );
+        assert_eq!(
+            ConcreteDataType::int64_datatype(),
+            function
+                .return_type(&[ConcreteDataType::int32_datatype()])
+                .unwrap()
+        );
+
+        let nums = vec![18, -17, 5, -6];
+        let divs = vec![4, 8, -5, -5];
+
+        let args: Vec<VectorRef> = vec![
+            Arc::new(Int32Vector::from_vec(nums.clone())),
+            Arc::new(Int32Vector::from_vec(divs.clone())),
+        ];
+        let result = function.eval(FunctionContext::default(), &args).unwrap();
+        assert_eq!(result.len(), 4);
+        for (i, (num, div)) in nums.iter().zip(&divs).enumerate() {
+            let expected = (num % div) as i64;
+            assert!(matches!(result.get(i), Value::Int64(v) if v == expected));
+        }
+    }
+
+    #[test]
+    fn test_mod_function_unsigned() {
+        let function = ModuloFunction;
+        assert_eq!("mod", function.name());
+        assert_eq!(
+            ConcreteDataType::uint64_datatype(),
+            function
+                .return_type(&[ConcreteDataType::uint64_datatype()])
+                .unwrap()
+        );
+        assert_eq!(
+            ConcreteDataType::uint64_datatype(),
+            function
+                .return_type(&[ConcreteDataType::uint32_datatype()])
+                .unwrap()
+        );
+
+        let nums: Vec<u32> = vec![18, 17, 5, 6];
+        let divs: Vec<u32> = vec![4, 8, 5, 5];
+
+        let args: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_vec(nums.clone())),
+            Arc::new(UInt32Vector::from_vec(divs.clone())),
+        ];
+        let result = function.eval(FunctionContext::default(), &args).unwrap();
+        assert_eq!(result.len(), 4);
+        for (i, (num, div)) in nums.iter().zip(&divs).enumerate() {
+            let expected = (num % div) as u64;
+            assert!(matches!(result.get(i), Value::UInt64(v) if v == expected));
+        }
+    }
+
+    #[test]
+    fn test_mod_function_float() {
+        let function = ModuloFunction;
+        assert_eq!("mod", function.name());
+        assert_eq!(
+            ConcreteDataType::float64_datatype(),
+            function
+                .return_type(&[ConcreteDataType::float64_datatype()])
+                .unwrap()
+        );
+        assert_eq!(
+            ConcreteDataType::float64_datatype(),
+            function
+                .return_type(&[ConcreteDataType::float32_datatype()])
+                .unwrap()
+        );
+
+        let nums = vec![18.0, 17.0, 5.0, 6.0];
+        let divs = vec![4.0, 8.0, 5.0, 5.0];
+
+        let args: Vec<VectorRef> = vec![
+            Arc::new(Float64Vector::from_vec(nums.clone())),
+            Arc::new(Float64Vector::from_vec(divs.clone())),
+        ];
+        let result = function.eval(FunctionContext::default(), &args).unwrap();
+        assert_eq!(result.len(), 4);
+        for (i, (num, div)) in nums.iter().zip(&divs).enumerate() {
+            let expected: f64 = num % div;
+            assert!(matches!(result.get(i), Value::Float64(v) if v == expected));
+        }
+    }
+
+    #[test]
+    fn test_mod_function_errors() {
+        let function = ModuloFunction;
+        assert_eq!("mod", function.name());
+        let nums = vec![27];
+        let divs = vec![0];
+
+        let args: Vec<VectorRef> = vec![
+            Arc::new(Int32Vector::from_vec(nums.clone())),
+            Arc::new(Int32Vector::from_vec(divs.clone())),
+        ];
+        let result = function.eval(FunctionContext::default(), &args);
+        assert!(result.is_err());
+        let err_msg = result.unwrap_err().output_msg();
+        assert_eq!(
+            err_msg,
+            "Failed to perform compute operation on arrow arrays: Divide by zero error"
+        );
+
+        let nums = vec![27];
+
+        let args: Vec<VectorRef> = vec![Arc::new(Int32Vector::from_vec(nums.clone()))];
+        let result = function.eval(FunctionContext::default(), &args);
+        assert!(result.is_err());
+        let err_msg = result.unwrap_err().output_msg();
+        assert!(
+            err_msg.contains("The length of the args is not correct, expect exactly two, have: 1")
+        );
+
+        let nums = vec!["27"];
+        let divs = vec!["4"];
+        let args: Vec<VectorRef> = vec![
+            Arc::new(StringVector::from(nums.clone())),
+            Arc::new(StringVector::from(divs.clone())),
+        ];
+        let result = function.eval(FunctionContext::default(), &args);
+        assert!(result.is_err());
+        let err_msg = result.unwrap_err().output_msg();
+        assert!(err_msg.contains("Invalid arithmetic operation"));
+    }
+}
diff --git a/tests/cases/standalone/common/function/arithmetic.result b/tests/cases/standalone/common/function/arithmetic.result
new file mode 100644
index 0000000000..caa8f1e397
--- /dev/null
+++ b/tests/cases/standalone/common/function/arithmetic.result
@@ -0,0 +1,52 @@
+SELECT MOD(18, 4);
+
++-------------------------+
+| mod(Int64(18),Int64(4)) |
++-------------------------+
+| 2                       |
++-------------------------+
+
+SELECT MOD(-18, 4);
+
++--------------------------+
+| mod(Int64(-18),Int64(4)) |
++--------------------------+
+| -2                       |
++--------------------------+
+
+SELECT MOD(18.0, 4.0);
+
++-----------------------------+
+| mod(Float64(18),Float64(4)) |
++-----------------------------+
+| 2.0                         |
++-----------------------------+
+
+SELECT MOD(18, 0);
+
+Error: 3001(EngineExecuteQuery), DataFusion error: Divide by zero error
+
+SELECT POW (2, 5);
+
++------------------------+
+| pow(Int64(2),Int64(5)) |
++------------------------+
+| 32.0                   |
++------------------------+
+
+SELECT POW (1.01, 365);
+
++-------------------------------+
+| pow(Float64(1.01),Int64(365)) |
++-------------------------------+
+| 37.78343433288728             |
++-------------------------------+
+
+SELECT POW (0.99, 365);
+
++-------------------------------+
+| pow(Float64(0.99),Int64(365)) |
++-------------------------------+
+| 0.025517964452291125          |
++-------------------------------+
+
diff --git a/tests/cases/standalone/common/function/arithmetic.sql b/tests/cases/standalone/common/function/arithmetic.sql
new file mode 100644
index 0000000000..fd048f00a0
--- /dev/null
+++ b/tests/cases/standalone/common/function/arithmetic.sql
@@ -0,0 +1,16 @@
+
+SELECT MOD(18, 4);
+
+SELECT MOD(-18, 4);
+
+SELECT MOD(18.0, 4.0);
+
+SELECT MOD(18, 0);
+
+
+
+SELECT POW (2, 5);
+
+SELECT POW (1.01, 365);
+
+SELECT POW (0.99, 365);