diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index 77344ecab4..90aed7cbd7 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -20,11 +20,12 @@ pub mod impl_conv; pub(crate) mod product; mod scalar_add; mod scalar_mul; -mod sub; pub(crate) mod sum; +mod vector_add; mod vector_div; mod vector_mul; mod vector_norm; +mod vector_sub; use std::sync::Arc; @@ -48,10 +49,11 @@ impl VectorFunction { registry.register(Arc::new(scalar_mul::ScalarMulFunction)); // vector calculation + registry.register(Arc::new(vector_add::VectorAddFunction)); + registry.register(Arc::new(vector_sub::VectorSubFunction)); registry.register(Arc::new(vector_mul::VectorMulFunction)); - registry.register(Arc::new(vector_norm::VectorNormFunction)); registry.register(Arc::new(vector_div::VectorDivFunction)); - registry.register(Arc::new(sub::SubFunction)); + registry.register(Arc::new(vector_norm::VectorNormFunction)); registry.register(Arc::new(elem_sum::ElemSumFunction)); registry.register(Arc::new(elem_product::ElemProductFunction)); } diff --git a/src/common/function/src/scalars/vector/vector_add.rs b/src/common/function/src/scalars/vector/vector_add.rs new file mode 100644 index 0000000000..f0fd9bbbc3 --- /dev/null +++ b/src/common/function/src/scalars/vector/vector_add.rs @@ -0,0 +1,214 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::fmt::Display; + +use common_query::error::InvalidFuncArgsSnafu; +use common_query::prelude::Signature; +use datatypes::prelude::ConcreteDataType; +use datatypes::scalars::ScalarVectorBuilder; +use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef}; +use nalgebra::DVectorView; +use snafu::ensure; + +use crate::function::{Function, FunctionContext}; +use crate::helper; +use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit}; + +const NAME: &str = "vec_add"; + +/// Adds corresponding elements of two vectors, returns a vector. +/// +/// # Example +/// +/// ```sql +/// SELECT vec_to_string(vec_add("[1.0, 1.0]", "[1.0, 2.0]")) as result; +/// +/// +---------------------------------------------------------------+ +/// | vec_to_string(vec_add(Utf8("[1.0, 1.0]"),Utf8("[1.0, 2.0]"))) | +/// +---------------------------------------------------------------+ +/// | [2,3] | +/// +---------------------------------------------------------------+ +/// +#[derive(Debug, Clone, Default)] +pub struct VectorAddFunction; + +impl Function for VectorAddFunction { + fn name(&self) -> &str { + NAME + } + + fn return_type( + &self, + _input_types: &[ConcreteDataType], + ) -> common_query::error::Result { + Ok(ConcreteDataType::binary_datatype()) + } + + fn signature(&self) -> Signature { + helper::one_of_sigs2( + vec![ + ConcreteDataType::string_datatype(), + ConcreteDataType::binary_datatype(), + ], + vec![ + ConcreteDataType::string_datatype(), + ConcreteDataType::binary_datatype(), + ], + ) + } + + fn eval( + &self, + _func_ctx: FunctionContext, + columns: &[VectorRef], + ) -> common_query::error::Result { + ensure!( + columns.len() == 2, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect exactly two, have: {}", + columns.len() + ) + } + ); + let arg0 = &columns[0]; + let arg1 = &columns[1]; + + ensure!( + arg0.len() == arg1.len(), + InvalidFuncArgsSnafu { + err_msg: format!( + "The lengths of the vector are not aligned, args 0: {}, args 1: {}", + arg0.len(), + arg1.len(), + ) + } + ); + + let len = arg0.len(); + let mut result = BinaryVectorBuilder::with_capacity(len); + if len == 0 { + return Ok(result.to_vector()); + } + + let arg0_const = as_veclit_if_const(arg0)?; + let arg1_const = as_veclit_if_const(arg1)?; + + for i in 0..len { + let arg0 = match arg0_const.as_ref() { + Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())), + None => as_veclit(arg0.get_ref(i))?, + }; + let arg1 = match arg1_const.as_ref() { + Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())), + None => as_veclit(arg1.get_ref(i))?, + }; + let (Some(arg0), Some(arg1)) = (arg0, arg1) else { + result.push_null(); + continue; + }; + let vec0 = DVectorView::from_slice(&arg0, arg0.len()); + let vec1 = DVectorView::from_slice(&arg1, arg1.len()); + + let vec_res = vec0 + vec1; + let veclit = vec_res.as_slice(); + let binlit = veclit_to_binlit(veclit); + result.push(Some(&binlit)); + } + + Ok(result.to_vector()) + } +} + +impl Display for VectorAddFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", NAME.to_ascii_uppercase()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_query::error::Error; + use datatypes::vectors::StringVector; + + use super::*; + + #[test] + fn test_sub() { + let func = VectorAddFunction; + + let input0 = Arc::new(StringVector::from(vec![ + Some("[1.0,2.0,3.0]".to_string()), + Some("[4.0,5.0,6.0]".to_string()), + None, + Some("[2.0,3.0,3.0]".to_string()), + ])); + let input1 = Arc::new(StringVector::from(vec![ + Some("[1.0,1.0,1.0]".to_string()), + Some("[6.0,5.0,4.0]".to_string()), + Some("[3.0,2.0,2.0]".to_string()), + None, + ])); + + let result = func + .eval(FunctionContext::default(), &[input0, input1]) + .unwrap(); + + let result = result.as_ref(); + assert_eq!(result.len(), 4); + assert_eq!( + result.get_ref(0).as_binary().unwrap(), + Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()) + ); + assert_eq!( + result.get_ref(1).as_binary().unwrap(), + Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()) + ); + assert!(result.get_ref(2).is_null()); + assert!(result.get_ref(3).is_null()); + } + + #[test] + fn test_sub_error() { + let func = VectorAddFunction; + + let input0 = Arc::new(StringVector::from(vec![ + Some("[1.0,2.0,3.0]".to_string()), + Some("[4.0,5.0,6.0]".to_string()), + None, + Some("[2.0,3.0,3.0]".to_string()), + ])); + let input1 = Arc::new(StringVector::from(vec![ + Some("[1.0,1.0,1.0]".to_string()), + Some("[6.0,5.0,4.0]".to_string()), + Some("[3.0,2.0,2.0]".to_string()), + ])); + + let result = func.eval(FunctionContext::default(), &[input0, input1]); + + match result { + Err(Error::InvalidFuncArgs { err_msg, .. }) => { + assert_eq!( + err_msg, + "The lengths of the vector are not aligned, args 0: 4, args 1: 3" + ) + } + _ => unreachable!(), + } + } +} diff --git a/src/common/function/src/scalars/vector/sub.rs b/src/common/function/src/scalars/vector/vector_sub.rs similarity index 91% rename from src/common/function/src/scalars/vector/sub.rs rename to src/common/function/src/scalars/vector/vector_sub.rs index 6f56bd9fcd..7f97bb322e 100644 --- a/src/common/function/src/scalars/vector/sub.rs +++ b/src/common/function/src/scalars/vector/vector_sub.rs @@ -42,19 +42,10 @@ const NAME: &str = "vec_sub"; /// | [0,-1] | /// +---------------------------------------------------------------+ /// -/// -- Negative scalar to simulate subtraction -/// SELECT vec_to_string(vec_sub('[-1.0, -1.0]', '[1.0, 2.0]')); -/// -/// +-----------------------------------------------------------------+ -/// | vec_to_string(vec_sub(Utf8("[-1.0, -1.0]"),Utf8("[1.0, 2.0]"))) | -/// +-----------------------------------------------------------------+ -/// | [-2,-3] | -/// +-----------------------------------------------------------------+ -/// #[derive(Debug, Clone, Default)] -pub struct SubFunction; +pub struct VectorSubFunction; -impl Function for SubFunction { +impl Function for VectorSubFunction { fn name(&self) -> &str { NAME } @@ -142,7 +133,7 @@ impl Function for SubFunction { } } -impl Display for SubFunction { +impl Display for VectorSubFunction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", NAME.to_ascii_uppercase()) } @@ -159,7 +150,7 @@ mod tests { #[test] fn test_sub() { - let func = SubFunction; + let func = VectorSubFunction; let input0 = Arc::new(StringVector::from(vec![ Some("[1.0,2.0,3.0]".to_string()), @@ -194,7 +185,7 @@ mod tests { #[test] fn test_sub_error() { - let func = SubFunction; + let func = VectorSubFunction; let input0 = Arc::new(StringVector::from(vec![ Some("[1.0,2.0,3.0]".to_string()), diff --git a/src/datatypes/src/vectors/binary.rs b/src/datatypes/src/vectors/binary.rs index 45aa40d5e8..6c0b47b803 100644 --- a/src/datatypes/src/vectors/binary.rs +++ b/src/datatypes/src/vectors/binary.rs @@ -77,27 +77,32 @@ impl BinaryVector { .unwrap() .iter() { - let v = if let Some(binary) = binary { - let bytes_size = dim as usize * std::mem::size_of::(); - if let Ok(s) = String::from_utf8(binary.to_vec()) { - let v = parse_string_to_vector_type_value(&s, Some(dim))?; - Some(v) - } else if binary.len() == dim as usize * std::mem::size_of::() { - Some(binary.to_vec()) - } else { - return InvalidVectorSnafu { - msg: format!( - "Unexpected bytes size for vector value, expected {}, got {}", - bytes_size, - binary.len() - ), - } - .fail(); - } - } else { - None + let Some(binary) = binary else { + vector.push(None); + continue; }; - vector.push(v); + + if let Ok(s) = String::from_utf8(binary.to_vec()) { + if let Ok(v) = parse_string_to_vector_type_value(&s, Some(dim)) { + vector.push(Some(v)); + continue; + } + } + + let expected_bytes_size = dim as usize * std::mem::size_of::(); + if binary.len() == expected_bytes_size { + vector.push(Some(binary.to_vec())); + continue; + } else { + return InvalidVectorSnafu { + msg: format!( + "Unexpected bytes size for vector value, expected {}, got {}", + expected_bytes_size, + binary.len() + ), + } + .fail(); + } } Ok(BinaryVector::from(vector)) } diff --git a/tests/cases/standalone/common/function/vector/vector.result b/tests/cases/standalone/common/function/vector/vector.result index 945072411c..1b81fa98b0 100644 --- a/tests/cases/standalone/common/function/vector/vector.result +++ b/tests/cases/standalone/common/function/vector/vector.result @@ -22,6 +22,30 @@ SELECT vec_to_string(parse_vec('[]')); | [] | +--------------------------------------+ +SELECT vec_to_string(vec_add('[1.0, 2.0]', '[3.0, 4.0]')); + ++---------------------------------------------------------------+ +| vec_to_string(vec_add(Utf8("[1.0, 2.0]"),Utf8("[3.0, 4.0]"))) | ++---------------------------------------------------------------+ +| [4,6] | ++---------------------------------------------------------------+ + +SELECT vec_to_string(vec_add(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]')); + ++--------------------------------------------------------------------------+ +| vec_to_string(vec_add(parse_vec(Utf8("[1.0, 2.0]")),Utf8("[3.0, 4.0]"))) | ++--------------------------------------------------------------------------+ +| [4,6] | ++--------------------------------------------------------------------------+ + +SELECT vec_to_string(vec_add('[1.0, 2.0]', parse_vec('[3.0, 4.0]'))); + ++--------------------------------------------------------------------------+ +| vec_to_string(vec_add(Utf8("[1.0, 2.0]"),parse_vec(Utf8("[3.0, 4.0]")))) | ++--------------------------------------------------------------------------+ +| [4,6] | ++--------------------------------------------------------------------------+ + SELECT vec_to_string(vec_mul('[1.0, 2.0]', '[3.0, 4.0]')); +---------------------------------------------------------------+ @@ -230,3 +254,33 @@ SELECT vec_to_string(vec_norm(parse_vec('[7.0, -8.0, 9.0]'))); | [0.5025707,-0.5743665,0.64616233] | +--------------------------------------------------------------+ +SELECT vec_to_string(vec_sum(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[-1.0, -2.0, -3.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +); + ++---------------------------+ +| vec_to_string(vec_sum(v)) | ++---------------------------+ +| [4,5,6] | ++---------------------------+ + +SELECT vec_to_string(vec_product(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[-1.0, -2.0, -3.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +); + ++-------------------------------+ +| vec_to_string(vec_product(v)) | ++-------------------------------+ +| [-4,-20,-54] | ++-------------------------------+ + diff --git a/tests/cases/standalone/common/function/vector/vector.sql b/tests/cases/standalone/common/function/vector/vector.sql index feffa85be3..49c8e88f28 100644 --- a/tests/cases/standalone/common/function/vector/vector.sql +++ b/tests/cases/standalone/common/function/vector/vector.sql @@ -4,6 +4,12 @@ SELECT vec_to_string(parse_vec('[1.0, 2.0, 3.0]')); SELECT vec_to_string(parse_vec('[]')); +SELECT vec_to_string(vec_add('[1.0, 2.0]', '[3.0, 4.0]')); + +SELECT vec_to_string(vec_add(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]')); + +SELECT vec_to_string(vec_add('[1.0, 2.0]', parse_vec('[3.0, 4.0]'))); + SELECT vec_to_string(vec_mul('[1.0, 2.0]', '[3.0, 4.0]')); SELECT vec_to_string(vec_mul(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]')); @@ -55,3 +61,21 @@ SELECT vec_to_string(vec_norm('[7.0, 8.0, 9.0]')); SELECT vec_to_string(vec_norm('[7.0, -8.0, 9.0]')); SELECT vec_to_string(vec_norm(parse_vec('[7.0, -8.0, 9.0]'))); + +SELECT vec_to_string(vec_sum(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[-1.0, -2.0, -3.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +); + +SELECT vec_to_string(vec_product(v)) +FROM ( + SELECT '[1.0, 2.0, 3.0]' AS v + UNION ALL + SELECT '[-1.0, -2.0, -3.0]' AS v + UNION ALL + SELECT '[4.0, 5.0, 6.0]' AS v +);