mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-03 20:02:54 +00:00
feat: add vec_subvector function (#5683)
* feat: add vec_subvector function * change datatype of arg1 and arg2 from u64 to i64 * add sqlness test * improve description comments
This commit is contained in:
@@ -27,6 +27,7 @@ mod vector_div;
|
||||
mod vector_mul;
|
||||
mod vector_norm;
|
||||
mod vector_sub;
|
||||
mod vector_subvector;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -56,6 +57,7 @@ impl VectorFunction {
|
||||
registry.register(Arc::new(vector_div::VectorDivFunction));
|
||||
registry.register(Arc::new(vector_norm::VectorNormFunction));
|
||||
registry.register(Arc::new(vector_dim::VectorDimFunction));
|
||||
registry.register(Arc::new(vector_subvector::VectorSubvectorFunction));
|
||||
registry.register(Arc::new(elem_sum::ElemSumFunction));
|
||||
registry.register(Arc::new(elem_product::ElemProductFunction));
|
||||
}
|
||||
|
||||
240
src/common/function/src/scalars/vector/vector_subvector.rs
Normal file
240
src/common/function/src/scalars/vector/vector_subvector.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use common_query::prelude::{Signature, TypeSignature};
|
||||
use datafusion_expr::Volatility;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
|
||||
|
||||
const NAME: &str = "vec_subvector";
|
||||
|
||||
/// Returns a subvector from start(included) to end(excluded) index.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```sql
|
||||
/// SELECT vec_to_string(vec_subvector("[1, 2, 3, 4, 5]", 1, 3)) as result;
|
||||
///
|
||||
/// +---------+
|
||||
/// | result |
|
||||
/// +---------+
|
||||
/// | [2, 3] |
|
||||
/// +---------+
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct VectorSubvectorFunction;
|
||||
|
||||
impl Function for VectorSubvectorFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
|
||||
Ok(ConcreteDataType::binary_datatype())
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::one_of(
|
||||
vec![
|
||||
TypeSignature::Exact(vec![
|
||||
ConcreteDataType::string_datatype(),
|
||||
ConcreteDataType::int64_datatype(),
|
||||
ConcreteDataType::int64_datatype(),
|
||||
]),
|
||||
TypeSignature::Exact(vec![
|
||||
ConcreteDataType::binary_datatype(),
|
||||
ConcreteDataType::int64_datatype(),
|
||||
ConcreteDataType::int64_datatype(),
|
||||
]),
|
||||
],
|
||||
Volatility::Immutable,
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 3,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly three, have: {}",
|
||||
columns.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
let arg2 = &columns[2];
|
||||
|
||||
ensure!(
|
||||
arg0.len() == arg1.len() && arg1.len() == arg2.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
|
||||
arg0.len(),
|
||||
arg1.len(),
|
||||
arg2.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = BinaryVectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let arg1 = arg1.get(i).as_i64();
|
||||
let arg2 = arg2.get(i).as_i64();
|
||||
let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
|
||||
ensure!(
|
||||
0 <= arg1 && arg1 <= arg2 && arg2 as usize <= arg0.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Invalid start and end indices: start={}, end={}, vec_len={}",
|
||||
arg1,
|
||||
arg2,
|
||||
arg0.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
let subvector = &arg0[arg1 as usize..arg2 as usize];
|
||||
let binlit = veclit_to_binlit(subvector);
|
||||
result.push(Some(&binlit));
|
||||
}
|
||||
|
||||
Ok(result.to_vector())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for VectorSubvectorFunction {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::Error;
|
||||
use datatypes::vectors::{Int64Vector, StringVector};
|
||||
|
||||
use super::*;
|
||||
use crate::function::FunctionContext;
|
||||
#[test]
|
||||
fn test_subvector() {
|
||||
let func = VectorSubvectorFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
|
||||
Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
|
||||
None,
|
||||
Some("[11.0, 12.0, 13.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
|
||||
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1, input2])
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(
|
||||
result.get_ref(0).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
|
||||
);
|
||||
assert_eq!(
|
||||
result.get_ref(1).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
|
||||
);
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert_eq!(
|
||||
result.get_ref(3).as_binary().unwrap(),
|
||||
Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn test_subvector_error() {
|
||||
let func = VectorSubvectorFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
Some("[1.0, 2.0, 3.0]".to_string()),
|
||||
Some("[4.0, 5.0, 6.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
|
||||
let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subvector_invalid_indices() {
|
||||
let func = VectorSubvectorFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
Some("[1.0, 2.0, 3.0]".to_string()),
|
||||
Some("[4.0, 5.0, 6.0]".to_string()),
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
|
||||
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));
|
||||
|
||||
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
|
||||
|
||||
match result {
|
||||
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
"Invalid start and end indices: start=3, end=4, vec_len=3"
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -285,6 +285,20 @@ impl Value {
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast Value to i64. Return None if value is not a valid int64 data type.
|
||||
pub fn as_i64(&self) -> Option<i64> {
|
||||
match self {
|
||||
Value::Int8(v) => Some(*v as _),
|
||||
Value::Int16(v) => Some(*v as _),
|
||||
Value::Int32(v) => Some(*v as _),
|
||||
Value::Int64(v) => Some(*v),
|
||||
Value::UInt8(v) => Some(*v as _),
|
||||
Value::UInt16(v) => Some(*v as _),
|
||||
Value::UInt32(v) => Some(*v as _),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast Value to u64. Return None if value is not a valid uint64 data type.
|
||||
pub fn as_u64(&self) -> Option<u64> {
|
||||
match self {
|
||||
@@ -295,7 +309,6 @@ impl Value {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast Value to f64. Return None if it's not castable;
|
||||
pub fn as_f64_lossy(&self) -> Option<f64> {
|
||||
match self {
|
||||
|
||||
@@ -326,3 +326,53 @@ FROM (
|
||||
| [7.0, 8.0, 9.0, 10.0] | 4 |
|
||||
+-----------------------+------------+
|
||||
|
||||
SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 0, 3));
|
||||
|
||||
+-------------------------------------------------------------------------------+
|
||||
| vec_to_string(vec_subvector(Utf8("[1.0,2.0,3.0,4.0,5.0]"),Int64(0),Int64(3))) |
|
||||
+-------------------------------------------------------------------------------+
|
||||
| [1,2,3] |
|
||||
+-------------------------------------------------------------------------------+
|
||||
|
||||
SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 5, 5));
|
||||
|
||||
+-------------------------------------------------------------------------------+
|
||||
| vec_to_string(vec_subvector(Utf8("[1.0,2.0,3.0,4.0,5.0]"),Int64(5),Int64(5))) |
|
||||
+-------------------------------------------------------------------------------+
|
||||
| [] |
|
||||
+-------------------------------------------------------------------------------+
|
||||
|
||||
SELECT v, vec_to_string(vec_subvector(v, 3, 5))
|
||||
FROM (
|
||||
SELECT '[1.0, 2.0, 3.0, 4.0, 5.0]' AS v
|
||||
UNION ALL
|
||||
SELECT '[-1.0, -2.0, -3.0, -4.0, -5.0, -6.0]' AS v
|
||||
UNION ALL
|
||||
SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
|
||||
) ORDER BY v;
|
||||
|
||||
+--------------------------------------+---------------------------------------------------+
|
||||
| v | vec_to_string(vec_subvector(v,Int64(3),Int64(5))) |
|
||||
+--------------------------------------+---------------------------------------------------+
|
||||
| [-1.0, -2.0, -3.0, -4.0, -5.0, -6.0] | [-4,-5] |
|
||||
| [1.0, 2.0, 3.0, 4.0, 5.0] | [4,5] |
|
||||
| [4.0, 5.0, 6.0, 10, -8, 100] | [10,-8] |
|
||||
+--------------------------------------+---------------------------------------------------+
|
||||
|
||||
SELECT vec_to_string(vec_subvector(v, 0, 5))
|
||||
FROM (
|
||||
SELECT '[1.1, 2.2, 3.3, 4.4, 5.5]' AS v
|
||||
UNION ALL
|
||||
SELECT '[-1.1, -2.1, -3.1, -4.1, -5.1, -6.1]' AS v
|
||||
UNION ALL
|
||||
SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
|
||||
) ORDER BY v;
|
||||
|
||||
+---------------------------------------------------+
|
||||
| vec_to_string(vec_subvector(v,Int64(0),Int64(5))) |
|
||||
+---------------------------------------------------+
|
||||
| [-1.1,-2.1,-3.1,-4.1,-5.1] |
|
||||
| [1.1,2.2,3.3,4.4,5.5] |
|
||||
| [4,5,6,10,-8] |
|
||||
+---------------------------------------------------+
|
||||
|
||||
|
||||
@@ -99,3 +99,26 @@ FROM (
|
||||
UNION ALL
|
||||
SELECT '[7.0, 8.0, 9.0, 10.0]' AS v
|
||||
) Order By vec_dim(v) ASC;
|
||||
|
||||
SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 0, 3));
|
||||
|
||||
SELECT vec_to_string(vec_subvector('[1.0,2.0,3.0,4.0,5.0]', 5, 5));
|
||||
|
||||
SELECT v, vec_to_string(vec_subvector(v, 3, 5))
|
||||
FROM (
|
||||
SELECT '[1.0, 2.0, 3.0, 4.0, 5.0]' AS v
|
||||
UNION ALL
|
||||
SELECT '[-1.0, -2.0, -3.0, -4.0, -5.0, -6.0]' AS v
|
||||
UNION ALL
|
||||
SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
|
||||
) ORDER BY v;
|
||||
|
||||
SELECT vec_to_string(vec_subvector(v, 0, 5))
|
||||
FROM (
|
||||
SELECT '[1.1, 2.2, 3.3, 4.4, 5.5]' AS v
|
||||
UNION ALL
|
||||
SELECT '[-1.1, -2.1, -3.1, -4.1, -5.1, -6.1]' AS v
|
||||
UNION ALL
|
||||
SELECT '[4.0, 5.0, 6.0, 10, -8, 100]' AS v
|
||||
) ORDER BY v;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user