mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-25 17:30:41 +00:00
feat: add vec_kth_elem function (#5674)
* feat: add vec_kth_elem function Signed-off-by: pikady <2652917633@qq.com> * code format Signed-off-by: pikady <2652917633@qq.com> * add test sql Signed-off-by: pikady <2652917633@qq.com> * change indexing from 1-based to 0-based Signed-off-by: pikady <2652917633@qq.com> * improve code formatting and correct spelling errors Signed-off-by: pikady <2652917633@qq.com> * Update tests/cases/standalone/common/function/vector/vector.sql I noticed the two lines are identical. Could you clarify the reason for the change? Thanks! Co-authored-by: Zhenchi <zhongzc_arch@outlook.com> --------- Signed-off-by: pikady <2652917633@qq.com> Co-authored-by: Zhenchi <zhongzc_arch@outlook.com>
This commit is contained in:
@@ -24,6 +24,7 @@ pub(crate) mod sum;
|
||||
mod vector_add;
|
||||
mod vector_dim;
|
||||
mod vector_div;
|
||||
mod vector_kth_elem;
|
||||
mod vector_mul;
|
||||
mod vector_norm;
|
||||
mod vector_sub;
|
||||
@@ -57,6 +58,7 @@ impl VectorFunction {
|
||||
registry.register(Arc::new(vector_div::VectorDivFunction));
|
||||
registry.register(Arc::new(vector_norm::VectorNormFunction));
|
||||
registry.register(Arc::new(vector_dim::VectorDimFunction));
|
||||
registry.register(Arc::new(vector_kth_elem::VectorKthElemFunction));
|
||||
registry.register(Arc::new(vector_subvector::VectorSubvectorFunction));
|
||||
registry.register(Arc::new(elem_sum::ElemSumFunction));
|
||||
registry.register(Arc::new(elem_product::ElemProductFunction));
|
||||
|
||||
211
src/common/function/src/scalars/vector/vector_kth_elem.rs
Normal file
211
src/common/function/src/scalars/vector/vector_kth_elem.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use common_query::prelude::Signature;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::helper;
|
||||
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
|
||||
|
||||
const NAME: &str = "vec_kth_elem";
|
||||
|
||||
/// Returns the k-th(0-based index) element of the vector.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```sql
|
||||
/// SELECT vec_kth_elem("[2, 4, 6]",1) as result;
|
||||
///
|
||||
/// +---------+
|
||||
/// | result |
|
||||
/// +---------+
|
||||
/// | 4 |
|
||||
/// +---------+
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct VectorKthElemFunction;
|
||||
|
||||
impl Function for VectorKthElemFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(
|
||||
&self,
|
||||
_input_types: &[ConcreteDataType],
|
||||
) -> common_query::error::Result<ConcreteDataType> {
|
||||
Ok(ConcreteDataType::float32_datatype())
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
helper::one_of_sigs2(
|
||||
vec![
|
||||
ConcreteDataType::string_datatype(),
|
||||
ConcreteDataType::binary_datatype(),
|
||||
],
|
||||
vec![ConcreteDataType::int64_datatype()],
|
||||
)
|
||||
}
|
||||
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly two, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
let arg0 = &columns[0];
|
||||
let arg1 = &columns[1];
|
||||
|
||||
let len = arg0.len();
|
||||
let mut result = Float32VectorBuilder::with_capacity(len);
|
||||
if len == 0 {
|
||||
return Ok(result.to_vector());
|
||||
};
|
||||
|
||||
let arg0_const = as_veclit_if_const(arg0)?;
|
||||
|
||||
for i in 0..len {
|
||||
let arg0 = match arg0_const.as_ref() {
|
||||
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
|
||||
None => as_veclit(arg0.get_ref(i))?,
|
||||
};
|
||||
let Some(arg0) = arg0 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
|
||||
let arg1 = arg1.get(i).as_f64_lossy();
|
||||
let Some(arg1) = arg1 else {
|
||||
result.push_null();
|
||||
continue;
|
||||
};
|
||||
|
||||
ensure!(
|
||||
arg1 >= 0.0 && arg1.fract() == 0.0,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Invalid argument: k must be a non-negative integer, but got k = {}.",
|
||||
arg1
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
let k = arg1 as usize;
|
||||
|
||||
ensure!(
|
||||
k < arg0.len(),
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Out of range: k must be in the range [0, {}], but got k = {}.",
|
||||
arg0.len() - 1,
|
||||
k
|
||||
),
|
||||
}
|
||||
);
|
||||
|
||||
let value = arg0[k];
|
||||
|
||||
result.push(Some(value));
|
||||
}
|
||||
Ok(result.to_vector())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for VectorKthElemFunction {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error;
|
||||
use datatypes::vectors::{Int64Vector, StringVector};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_vec_kth_elem() {
|
||||
let func = VectorKthElemFunction;
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![
|
||||
Some("[1.0,2.0,3.0]".to_string()),
|
||||
Some("[4.0,5.0,6.0]".to_string()),
|
||||
Some("[7.0,8.0,9.0]".to_string()),
|
||||
None,
|
||||
]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));
|
||||
|
||||
let result = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.unwrap();
|
||||
|
||||
let result = result.as_ref();
|
||||
assert_eq!(result.len(), 4);
|
||||
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
|
||||
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
|
||||
assert!(result.get_ref(2).is_null());
|
||||
assert!(result.get_ref(3).is_null());
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(3)]));
|
||||
|
||||
let err = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.unwrap_err();
|
||||
match err {
|
||||
error::Error::InvalidFuncArgs { err_msg, .. } => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
format!("Out of range: k must be in the range [0, 2], but got k = 3.")
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
|
||||
let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));
|
||||
|
||||
let err = func
|
||||
.eval(&FunctionContext::default(), &[input0, input1])
|
||||
.unwrap_err();
|
||||
match err {
|
||||
error::Error::InvalidFuncArgs { err_msg, .. } => {
|
||||
assert_eq!(
|
||||
err_msg,
|
||||
format!("Invalid argument: k must be a non-negative integer, but got k = -1.")
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user