From eea8b1c730afb38df1998338872a8b2ee5fd2881 Mon Sep 17 00:00:00 2001 From: pikady <68271745+Pikady@users.noreply.github.com> Date: Tue, 18 Mar 2025 15:25:53 +0800 Subject: [PATCH] feat: add `vec_kth_elem` function (#5674) * feat: add vec_kth_elem function Signed-off-by: pikady <2652917633@qq.com> * code format Signed-off-by: pikady <2652917633@qq.com> * add test sql Signed-off-by: pikady <2652917633@qq.com> * change indexing from 1-based to 0-based Signed-off-by: pikady <2652917633@qq.com> * improve code formatting and correct spelling errors Signed-off-by: pikady <2652917633@qq.com> * Update tests/cases/standalone/common/function/vector/vector.sql I noticed the two lines are identical. Could you clarify the reason for the change? Thanks! Co-authored-by: Zhenchi --------- Signed-off-by: pikady <2652917633@qq.com> Co-authored-by: Zhenchi --- src/common/function/src/scalars/vector.rs | 2 + .../src/scalars/vector/vector_kth_elem.rs | 211 ++++++++++++++++++ .../common/function/vector/vector.result | 25 +++ .../common/function/vector/vector.sql | 13 +- 4 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 src/common/function/src/scalars/vector/vector_kth_elem.rs diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index b7c3d193f3..d8dc195e5b 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -24,6 +24,7 @@ pub(crate) mod sum; mod vector_add; mod vector_dim; mod vector_div; +mod vector_kth_elem; mod vector_mul; mod vector_norm; mod vector_sub; @@ -57,6 +58,7 @@ impl VectorFunction { registry.register(Arc::new(vector_div::VectorDivFunction)); registry.register(Arc::new(vector_norm::VectorNormFunction)); registry.register(Arc::new(vector_dim::VectorDimFunction)); + registry.register(Arc::new(vector_kth_elem::VectorKthElemFunction)); registry.register(Arc::new(vector_subvector::VectorSubvectorFunction)); registry.register(Arc::new(elem_sum::ElemSumFunction)); registry.register(Arc::new(elem_product::ElemProductFunction)); diff --git a/src/common/function/src/scalars/vector/vector_kth_elem.rs b/src/common/function/src/scalars/vector/vector_kth_elem.rs new file mode 100644 index 0000000000..2c1cd48e93 --- /dev/null +++ b/src/common/function/src/scalars/vector/vector_kth_elem.rs @@ -0,0 +1,211 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::fmt::Display; + +use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::prelude::Signature; +use datatypes::prelude::ConcreteDataType; +use datatypes::scalars::ScalarVectorBuilder; +use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef}; +use snafu::ensure; + +use crate::function::{Function, FunctionContext}; +use crate::helper; +use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const}; + +const NAME: &str = "vec_kth_elem"; + +/// Returns the k-th(0-based index) element of the vector. +/// +/// # Example +/// +/// ```sql +/// SELECT vec_kth_elem("[2, 4, 6]",1) as result; +/// +/// +---------+ +/// | result | +/// +---------+ +/// | 4 | +/// +---------+ +/// +/// ``` +/// + +#[derive(Debug, Clone, Default)] +pub struct VectorKthElemFunction; + +impl Function for VectorKthElemFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type( + &self, + _input_types: &[ConcreteDataType], + ) -> common_query::error::Result { + Ok(ConcreteDataType::float32_datatype()) + } + + fn signature(&self) -> Signature { + helper::one_of_sigs2( + vec![ + ConcreteDataType::string_datatype(), + ConcreteDataType::binary_datatype(), + ], + vec![ConcreteDataType::int64_datatype()], + ) + } + + fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result { + ensure!( + columns.len() == 2, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect exactly two, have: {}", + columns.len() + ), + } + ); + + let arg0 = &columns[0]; + let arg1 = &columns[1]; + + let len = arg0.len(); + let mut result = Float32VectorBuilder::with_capacity(len); + if len == 0 { + return Ok(result.to_vector()); + }; + + let arg0_const = as_veclit_if_const(arg0)?; + + for i in 0..len { + let arg0 = match arg0_const.as_ref() { + Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), + None => as_veclit(arg0.get_ref(i))?, + }; + let Some(arg0) = arg0 else { + result.push_null(); + continue; + }; + + let arg1 = arg1.get(i).as_f64_lossy(); + let Some(arg1) = arg1 else { + result.push_null(); + continue; + }; + + ensure!( + arg1 >= 0.0 && arg1.fract() == 0.0, + InvalidFuncArgsSnafu { + err_msg: format!( + "Invalid argument: k must be a non-negative integer, but got k = {}.", + arg1 + ), + } + ); + + let k = arg1 as usize; + + ensure!( + k < arg0.len(), + InvalidFuncArgsSnafu { + err_msg: format!( + "Out of range: k must be in the range [0, {}], but got k = {}.", + arg0.len() - 1, + k + ), + } + ); + + let value = arg0[k]; + + result.push(Some(value)); + } + Ok(result.to_vector()) + } +} + +impl Display for VectorKthElemFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_query::error; + use datatypes::vectors::{Int64Vector, StringVector}; + + use super::*; + + #[test] + fn test_vec_kth_elem() { + let func = VectorKthElemFunction; + + let input0 = Arc::new(StringVector::from(vec![ + Some("[1.0,2.0,3.0]".to_string()), + Some("[4.0,5.0,6.0]".to_string()), + Some("[7.0,8.0,9.0]".to_string()), + None, + ])); + let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)])); + + let result = func + .eval(&FunctionContext::default(), &[input0, input1]) + .unwrap(); + + let result = result.as_ref(); + assert_eq!(result.len(), 4); + assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0)); + assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0)); + assert!(result.get_ref(2).is_null()); + assert!(result.get_ref(3).is_null()); + + let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())])); + let input1 = Arc::new(Int64Vector::from(vec![Some(3)])); + + let err = func + .eval(&FunctionContext::default(), &[input0, input1]) + .unwrap_err(); + match err { + error::Error::InvalidFuncArgs { err_msg, .. } => { + assert_eq!( + err_msg, + format!("Out of range: k must be in the range [0, 2], but got k = 3.") + ) + } + _ => unreachable!(), + } + + let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())])); + let input1 = Arc::new(Int64Vector::from(vec![Some(-1)])); + + let err = func + .eval(&FunctionContext::default(), &[input0, input1]) + .unwrap_err(); + match err { + error::Error::InvalidFuncArgs { err_msg, .. } => { + assert_eq!( + err_msg, + format!("Invalid argument: k must be a non-negative integer, but got k = -1.") + ) + } + _ => unreachable!(), + } + } +} diff --git a/tests/cases/standalone/common/function/vector/vector.result b/tests/cases/standalone/common/function/vector/vector.result index 859268b45f..996138dba2 100644 --- a/tests/cases/standalone/common/function/vector/vector.result +++ b/tests/cases/standalone/common/function/vector/vector.result @@ -326,6 +326,31 @@ FROM ( | [7.0, 8.0, 9.0, 10.0] | 4 | +-----------------------+------------+ +SELECT vec_kth_elem('[1.0, 2.0, 3.0]', 2); + ++------------------------------------------------+ +| vec_kth_elem(Utf8("[1.0, 2.0, 3.0]"),Int64(2)) | ++------------------------------------------------+ +| 3.0 | ++------------------------------------------------+ + +SELECT v, vec_kth_elem(v, 0) AS first_elem +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0, 7.0]' AS v + UNION ALL + SELECT '[8.0]' AS v + ) +WHERE vec_kth_elem(v, 0) > 2.0; + ++----------------------+------------+ +| v | first_elem | ++----------------------+------------+ +| [4.0, 5.0, 6.0, 7.0] | 4.0 | +| [8.0] | 8.0 | ++----------------------+------------+ + SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 0, 3)); +-------------------------------------------------------------------------------+ diff --git a/tests/cases/standalone/common/function/vector/vector.sql b/tests/cases/standalone/common/function/vector/vector.sql index 8cf9a1d188..05f7b6ee17 100644 --- a/tests/cases/standalone/common/function/vector/vector.sql +++ b/tests/cases/standalone/common/function/vector/vector.sql @@ -100,6 +100,18 @@ FROM ( SELECT '[7.0, 8.0, 9.0, 10.0]' AS v ) Order By vec_dim(v) ASC; +SELECT vec_kth_elem('[1.0, 2.0, 3.0]', 2); + +SELECT v, vec_kth_elem(v, 0) AS first_elem +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0, 7.0]' AS v + UNION ALL + SELECT '[8.0]' AS v + ) +WHERE vec_kth_elem(v, 0) > 2.0; + SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 0, 3)); SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 5, 5)); @@ -121,4 +133,3 @@ FROM ( UNION ALL SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v ) ORDER BY v; -