From 09dacc8e9b79e88b03382714644cb249060ce425 Mon Sep 17 00:00:00 2001
From: Sicong Hu <72676749+SNC123@users.noreply.github.com>
Date: Sun, 16 Mar 2025 18:43:53 +0800
Subject: [PATCH] feat: add `vec_subvector` function (#5683)

* feat: add vec_subvector function

* change datatype of arg1 and arg2 from u64 to i64

* add sqlness test

* improve description comments
---
 src/common/function/src/scalars/vector.rs     |   2 +
 .../src/scalars/vector/vector_subvector.rs    | 240 ++++++++++++++++++
 src/datatypes/src/value.rs                    |  15 +-
 .../common/function/vector/vector.result      |  50 ++++
 .../common/function/vector/vector.sql         |  23 ++
 5 files changed, 329 insertions(+), 1 deletion(-)
 create mode 100644 src/common/function/src/scalars/vector/vector_subvector.rs

diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs
index 381c757d9b..b7c3d193f3 100644
--- a/src/common/function/src/scalars/vector.rs
+++ b/src/common/function/src/scalars/vector.rs
@@ -27,6 +27,7 @@ mod vector_div;
 mod vector_mul;
 mod vector_norm;
 mod vector_sub;
+mod vector_subvector;
 
 use std::sync::Arc;
 
@@ -56,6 +57,7 @@ impl VectorFunction {
         registry.register(Arc::new(vector_div::VectorDivFunction));
         registry.register(Arc::new(vector_norm::VectorNormFunction));
         registry.register(Arc::new(vector_dim::VectorDimFunction));
+        registry.register(Arc::new(vector_subvector::VectorSubvectorFunction));
         registry.register(Arc::new(elem_sum::ElemSumFunction));
         registry.register(Arc::new(elem_product::ElemProductFunction));
     }
diff --git a/src/common/function/src/scalars/vector/vector_subvector.rs b/src/common/function/src/scalars/vector/vector_subvector.rs
new file mode 100644
index 0000000000..2836696853
--- /dev/null
+++ b/src/common/function/src/scalars/vector/vector_subvector.rs
@@ -0,0 +1,240 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::borrow::Cow;
+use std::fmt::Display;
+
+use common_query::error::{InvalidFuncArgsSnafu, Result};
+use common_query::prelude::{Signature, TypeSignature};
+use datafusion_expr::Volatility;
+use datatypes::prelude::ConcreteDataType;
+use datatypes::scalars::ScalarVectorBuilder;
+use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
+use snafu::ensure;
+
+use crate::function::{Function, FunctionContext};
+use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
+
+const NAME: &str = "vec_subvector";
+
+/// Returns a subvector from start(included) to end(excluded) index.
+///
+/// A null argument yields a null row; indices outside `0 <= start <= end <= len` are an error.
+///
+/// # Example
+///
+/// ```sql
+/// SELECT vec_to_string(vec_subvector('[1, 2, 3, 4, 5]', 1, 3)) as result;
+///
+/// +--------+
+/// | result |
+/// +--------+
+/// | [2,3]  |
+/// +--------+
+///
+/// ```
+#[derive(Debug, Clone, Default)]
+pub struct VectorSubvectorFunction;
+
+impl Function for VectorSubvectorFunction {
+    fn name(&self) -> &str {
+        NAME
+    }
+
+    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
+        Ok(ConcreteDataType::binary_datatype())
+    }
+
+    fn signature(&self) -> Signature {
+        Signature::one_of(
+            vec![
+                TypeSignature::Exact(vec![
+                    ConcreteDataType::string_datatype(),
+                    ConcreteDataType::int64_datatype(),
+                    ConcreteDataType::int64_datatype(),
+                ]),
+                TypeSignature::Exact(vec![
+                    ConcreteDataType::binary_datatype(),
+                    ConcreteDataType::int64_datatype(),
+                    ConcreteDataType::int64_datatype(),
+                ]),
+            ],
+            Volatility::Immutable,
+        )
+    }
+
+    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
+        ensure!(
+            columns.len() == 3,
+            InvalidFuncArgsSnafu {
+                err_msg: format!(
+                    "The length of the args is not correct, expect exactly three, have: {}",
+                    columns.len()
+                )
+            }
+        );
+
+        let arg0 = &columns[0];
+        let arg1 = &columns[1];
+        let arg2 = &columns[2];
+
+        ensure!(
+            arg0.len() == arg1.len() && arg1.len() == arg2.len(),
+            InvalidFuncArgsSnafu {
+                err_msg: format!(
+                    "The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
+                    arg0.len(),
+                    arg1.len(),
+                    arg2.len()
+                )
+            }
+        );
+
+        let len = arg0.len();
+        let mut result = BinaryVectorBuilder::with_capacity(len);
+        if len == 0 {
+            return Ok(result.to_vector());
+        }
+
+        let arg0_const = as_veclit_if_const(arg0)?;
+
+        for i in 0..len {
+            let arg0 = match arg0_const.as_ref() {
+                Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
+                None => as_veclit(arg0.get_ref(i))?,
+            };
+            let arg1 = arg1.get(i).as_i64();
+            let arg2 = arg2.get(i).as_i64();
+            let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
+                result.push_null();
+                continue;
+            };
+
+            ensure!(
+                0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
+                InvalidFuncArgsSnafu {
+                    err_msg: format!(
+                        "Invalid start and end indices: start={}, end={}, vec_len={}",
+                        arg1,
+                        arg2,
+                        arg0.len()
+                    )
+                }
+            );
+
+            let subvector = &arg0[arg1 as usize..arg2 as usize];
+            let binlit = veclit_to_binlit(subvector);
+            result.push(Some(&binlit));
+        }
+
+        Ok(result.to_vector())
+    }
+}
+
+impl Display for VectorSubvectorFunction {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", NAME.to_ascii_uppercase())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use common_query::error::Error;
+    use datatypes::vectors::{Int64Vector, StringVector};
+
+    use super::*;
+    use crate::function::FunctionContext;
+    #[test]
+    fn test_subvector() {
+        let func = VectorSubvectorFunction;
+
+        let input0 = Arc::new(StringVector::from(vec![
+            Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
+            Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
+            None,
+            Some("[11.0, 12.0, 13.0]".to_string()),
+        ]));
+        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
+        let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));
+
+        let result = func
+            .eval(&FunctionContext::default(), &[input0, input1, input2])
+            .unwrap();
+
+        let result = result.as_ref();
+        assert_eq!(result.len(), 4);
+        assert_eq!(
+            result.get_ref(0).as_binary().unwrap(),
+            Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
+        );
+        assert_eq!(
+            result.get_ref(1).as_binary().unwrap(),
+            Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
+        );
+        assert!(result.get_ref(2).is_null());
+        assert_eq!(
+            result.get_ref(3).as_binary().unwrap(),
+            Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
+        );
+    }
+    #[test]
+    fn test_subvector_error() {
+        let func = VectorSubvectorFunction;
+
+        let input0 = Arc::new(StringVector::from(vec![
+            Some("[1.0, 2.0, 3.0]".to_string()),
+            Some("[4.0, 5.0, 6.0]".to_string()),
+        ]));
+        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
+        let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));
+
+        let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
+
+        match result {
+            Err(Error::InvalidFuncArgs { err_msg, .. }) => {
+                assert_eq!(
+                    err_msg,
+                    "The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
+                )
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    #[test]
+    fn test_subvector_invalid_indices() {
+        let func = VectorSubvectorFunction;
+
+        let input0 = Arc::new(StringVector::from(vec![
+            Some("[1.0, 2.0, 3.0]".to_string()),
+            Some("[4.0, 5.0, 6.0]".to_string()),
+        ]));
+        let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
+        let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));
+
+        let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
+
+        match result {
+            Err(Error::InvalidFuncArgs { err_msg, .. }) => {
+                assert_eq!(
+                    err_msg,
+                    "Invalid start and end indices: start=3, end=4, vec_len=3"
+                )
+            }
+            _ => unreachable!(),
+        }
+    }
+}
diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs
index bf41d2b764..aa670d2fe8 100644
--- a/src/datatypes/src/value.rs
+++ b/src/datatypes/src/value.rs
@@ -285,6 +285,20 @@ impl Value {
         }
     }
 
+    /// Cast Value to i64. Return None unless it is Int8..Int64 or UInt8..UInt32 (lossless).
+    pub fn as_i64(&self) -> Option<i64> {
+        match self {
+            Value::Int8(v) => Some(*v as _),
+            Value::Int16(v) => Some(*v as _),
+            Value::Int32(v) => Some(*v as _),
+            Value::Int64(v) => Some(*v),
+            Value::UInt8(v) => Some(*v as _),
+            Value::UInt16(v) => Some(*v as _),
+            Value::UInt32(v) => Some(*v as _),
+            _ => None,
+        }
+    }
+
     /// Cast Value to u64. Return None if value is not a valid uint64 data type.
     pub fn as_u64(&self) -> Option<u64> {
         match self {
@@ -295,7 +309,6 @@ impl Value {
             _ => None,
         }
     }
-
     /// Cast Value to f64. Return None if it's not castable;
     pub fn as_f64_lossy(&self) -> Option<f64> {
         match self {
diff --git a/tests/cases/standalone/common/function/vector/vector.result b/tests/cases/standalone/common/function/vector/vector.result
index 7f40c73636..859268b45f 100644
--- a/tests/cases/standalone/common/function/vector/vector.result
+++ b/tests/cases/standalone/common/function/vector/vector.result
@@ -326,3 +326,53 @@ FROM (
 | [7.0, 8.0, 9.0, 10.0] | 4          |
 +-----------------------+------------+
 
+SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 0, 3));
+
++-------------------------------------------------------------------------------+
+| vec_to_string(vec_subvector(Utf8("[1.0,2.0,3.0,4.0,5.0]"),Int64(0),Int64(3))) |
++-------------------------------------------------------------------------------+
+| [1,2,3]                                                                       |
++-------------------------------------------------------------------------------+
+
+SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 5, 5));
+
++-------------------------------------------------------------------------------+
+| vec_to_string(vec_subvector(Utf8("[1.0,2.0,3.0,4.0,5.0]"),Int64(5),Int64(5))) |
++-------------------------------------------------------------------------------+
+| []                                                                            |
++-------------------------------------------------------------------------------+
+
+SELECT v, vec_to_string(vec_subvector(v, 3, 5))
+FROM (
+    SELECT '[1.0, 2.0, 3.0, 4.0, 5.0]' AS v
+    UNION ALL
+    SELECT '[-1.0, -2.0, -3.0, -4.0, -5.0, -6.0]' AS v
+    UNION ALL
+    SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
+) ORDER BY v;
+
++--------------------------------------+---------------------------------------------------+
+| v                                    | vec_to_string(vec_subvector(v,Int64(3),Int64(5))) |
++--------------------------------------+---------------------------------------------------+
+| [-1.0, -2.0, -3.0, -4.0, -5.0, -6.0] | [-4,-5]                                           |
+| [1.0, 2.0, 3.0, 4.0, 5.0]            | [4,5]                                             |
+| [4.0, 5.0, 6.0, 10, -8, 100]         | [10,-8]                                           |
++--------------------------------------+---------------------------------------------------+
+
+SELECT vec_to_string(vec_subvector(v, 0, 5))
+FROM (
+    SELECT '[1.1, 2.2, 3.3, 4.4, 5.5]' AS v
+    UNION ALL
+    SELECT '[-1.1, -2.1, -3.1, -4.1, -5.1, -6.1]' AS v
+    UNION ALL
+    SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
+) ORDER BY v;
+
++---------------------------------------------------+
+| vec_to_string(vec_subvector(v,Int64(0),Int64(5))) |
++---------------------------------------------------+
+| [-1.1,-2.1,-3.1,-4.1,-5.1]                        |
+| [1.1,2.2,3.3,4.4,5.5]                             |
+| [4,5,6,10,-8]                                     |
++---------------------------------------------------+
+
diff --git a/tests/cases/standalone/common/function/vector/vector.sql b/tests/cases/standalone/common/function/vector/vector.sql
index b53b6af453..8cf9a1d188 100644
--- a/tests/cases/standalone/common/function/vector/vector.sql
+++ b/tests/cases/standalone/common/function/vector/vector.sql
@@ -99,3 +99,26 @@ FROM (
     UNION ALL
     SELECT '[7.0, 8.0, 9.0, 10.0]' AS v
 ) Order By vec_dim(v) ASC;
+
+SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 0, 3));
+
+SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 5, 5));
+
+SELECT v, vec_to_string(vec_subvector(v, 3, 5))
+FROM (
+    SELECT '[1.0, 2.0, 3.0, 4.0, 5.0]' AS v
+    UNION ALL
+    SELECT '[-1.0, -2.0, -3.0, -4.0, -5.0, -6.0]' AS v
+    UNION ALL
+    SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
+) ORDER BY v;
+
+SELECT vec_to_string(vec_subvector(v, 0, 5))
+FROM (
+    SELECT '[1.1, 2.2, 3.3, 4.4, 5.5]' AS v
+    UNION ALL
+    SELECT '[-1.1, -2.1, -3.1, -4.1, -5.1, -6.1]' AS v
+    UNION ALL
+    SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
+) ORDER BY v;
+