mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-25 17:30:41 +00:00
feat: add distance functions (#4987)
* feat: add distance functions Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> * fix: f64 instead Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> * address comments Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> * tiny adjust Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> --------- Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
This commit is contained in:
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -1041,7 +1041,7 @@ dependencies = [
|
||||
"bitflags 2.6.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
@@ -2088,6 +2088,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"session",
|
||||
"simsimd",
|
||||
"snafu 0.8.5",
|
||||
"sql",
|
||||
"statrs",
|
||||
@@ -5090,7 +5091,7 @@ dependencies = [
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"socket2 0.4.10",
|
||||
"socket2 0.5.7",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
@@ -6080,7 +6081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.48.5",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8822,7 +8823,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"multimap",
|
||||
"once_cell",
|
||||
@@ -8874,7 +8875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.10.5",
|
||||
"itertools 0.12.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
@@ -9036,7 +9037,7 @@ dependencies = [
|
||||
"indoc",
|
||||
"libc",
|
||||
"memoffset 0.9.1",
|
||||
"parking_lot 0.11.2",
|
||||
"parking_lot 0.12.3",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
@@ -11198,6 +11199,15 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simsimd"
|
||||
version = "4.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "siphasher"
|
||||
version = "0.3.11"
|
||||
@@ -13981,7 +13991,7 @@ version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||
dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -41,6 +41,7 @@ s2 = { version = "0.0.12", optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
session.workspace = true
|
||||
simsimd = "4"
|
||||
snafu.workspace = true
|
||||
sql.workspace = true
|
||||
statrs = "0.16"
|
||||
|
||||
@@ -27,6 +27,7 @@ use crate::scalars::matches::MatchesFunction;
|
||||
use crate::scalars::math::MathFunction;
|
||||
use crate::scalars::numpy::NumpyFunction;
|
||||
use crate::scalars::timestamp::TimestampFunction;
|
||||
use crate::scalars::vector::VectorFunction;
|
||||
use crate::system::SystemFunction;
|
||||
use crate::table::TableFunction;
|
||||
|
||||
@@ -120,6 +121,9 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
|
||||
// Json related functions
|
||||
JsonFunction::register(&function_registry);
|
||||
|
||||
// Vector related functions
|
||||
VectorFunction::register(&function_registry);
|
||||
|
||||
// Geo functions
|
||||
#[cfg(feature = "geo")]
|
||||
crate::scalars::geo::GeoFunctions::register(&function_registry);
|
||||
|
||||
@@ -21,6 +21,7 @@ pub mod json;
|
||||
pub mod matches;
|
||||
pub mod math;
|
||||
pub mod numpy;
|
||||
pub mod vector;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test;
|
||||
|
||||
31
src/common/function/src/scalars/vector.rs
Normal file
31
src/common/function/src/scalars/vector.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod distance;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use distance::{CosDistanceFunction, DotProductFunction, L2SqDistanceFunction};
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
pub(crate) struct VectorFunction;
|
||||
|
||||
impl VectorFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register(Arc::new(CosDistanceFunction));
|
||||
registry.register(Arc::new(DotProductFunction));
|
||||
registry.register(Arc::new(L2SqDistanceFunction));
|
||||
}
|
||||
}
|
||||
469
src/common/function/src/scalars/vector/distance.rs
Normal file
469
src/common/function/src/scalars/vector/distance.rs
Normal file
@@ -0,0 +1,469 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use common_query::prelude::Signature;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::value::ValueRef;
|
||||
use datatypes::vectors::{Float64VectorBuilder, MutableVector, Vector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::helper;
|
||||
|
||||
/// Generates a unit struct that implements [`Function`] and [`Display`] for one
/// vector distance metric.
///
/// * `$StructName` - identifier of the generated struct.
/// * `$display_name` - string returned by `Function::name`; `Display` prints it
///   in ASCII upper case.
/// * `$similarity_method` - name of the `simsimd::SpatialSimilarity` method
///   (invoked on `f32` slices) that computes the metric.
///
/// The generated `eval` accepts exactly two columns, each string or binary.
/// A row yields NULL when either side is NULL, and the whole call fails with
/// `InvalidFuncArgs` when the two vectors of a row differ in length.
macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:ident) => {
        /// A function that calculates the distance between two vectors.
        #[derive(Debug, Clone, Default)]
        pub struct $StructName;

        impl Function for $StructName {
            fn name(&self) -> &str {
                $display_name
            }

            fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
                Ok(ConcreteDataType::float64_datatype())
            }

            fn signature(&self) -> Signature {
                helper::one_of_sigs2(
                    vec![
                        ConcreteDataType::string_datatype(),
                        ConcreteDataType::binary_datatype(),
                    ],
                    vec![
                        ConcreteDataType::string_datatype(),
                        ConcreteDataType::binary_datatype(),
                    ],
                )
            }

            fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
                ensure!(
                    columns.len() == 2,
                    InvalidFuncArgsSnafu {
                        err_msg: format!(
                            "The length of the args is not correct, expect exactly two, have: {}",
                            columns.len()
                        ),
                    }
                );
                let arg0 = &columns[0];
                let arg1 = &columns[1];

                let size = arg0.len();
                let mut result = Float64VectorBuilder::with_capacity(size);
                if size == 0 {
                    return Ok(result.to_vector());
                }

                // A constant string argument is parsed once up front instead of once per row.
                let arg0_const = parse_if_constant_string(arg0)?;
                let arg1_const = parse_if_constant_string(arg1)?;

                // The metric does not depend on the row, so resolve it outside the loop.
                let distance = <f32 as simsimd::SpatialSimilarity>::$similarity_method;

                for i in 0..size {
                    let vec0 = match arg0_const.as_ref() {
                        Some(a) => Some(Cow::Borrowed(a.as_slice())),
                        None => as_vector(arg0.get_ref(i))?,
                    };
                    let vec1 = match arg1_const.as_ref() {
                        Some(b) => Some(Cow::Borrowed(b.as_slice())),
                        None => as_vector(arg1.get_ref(i))?,
                    };

                    if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
                        ensure!(
                            vec0.len() == vec1.len(),
                            InvalidFuncArgsSnafu {
                                err_msg: format!(
                                    "The length of the vectors must match to calculate distance, have: {} vs {}",
                                    vec0.len(),
                                    vec1.len()
                                ),
                            }
                        );

                        // Safe: checked if the length of the vectors match
                        let d = distance(vec0.as_ref(), vec1.as_ref()).unwrap();
                        result.push(Some(d));
                    } else {
                        // NULL on either side produces a NULL distance.
                        result.push_null();
                    }
                }

                Ok(result.to_vector())
            }
        }

        impl Display for $StructName {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}", $display_name.to_ascii_uppercase())
            }
        }
    }
}
|
||||
|
||||
// `CosDistanceFunction`: named "cos_distance", computed with `SpatialSimilarity::cos`.
define_distance_function!(CosDistanceFunction, "cos_distance", cos);
|
||||
// `L2SqDistanceFunction`: named "l2sq_distance", computed with `SpatialSimilarity::l2sq`.
define_distance_function!(L2SqDistanceFunction, "l2sq_distance", l2sq);
|
||||
// `DotProductFunction`: named "dot_product", computed with `SpatialSimilarity::dot`.
define_distance_function!(DotProductFunction, "dot_product", dot);
|
||||
|
||||
/// Parse a vector value if the value is a constant string.
|
||||
fn parse_if_constant_string(arg: &Arc<dyn Vector>) -> Result<Option<Vec<f32>>> {
|
||||
if !arg.is_const() {
|
||||
return Ok(None);
|
||||
}
|
||||
if arg.data_type() != ConcreteDataType::string_datatype() {
|
||||
return Ok(None);
|
||||
}
|
||||
arg.get_ref(0)
|
||||
.as_string()
|
||||
.unwrap() // Safe: checked if it is a string
|
||||
.map(parse_f32_vector_from_string)
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Convert a value to a vector value.
|
||||
/// Supported data types are binary and string.
|
||||
fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
match arg.data_type() {
|
||||
ConcreteDataType::Binary(_) => arg
|
||||
.as_binary()
|
||||
.unwrap() // Safe: checked if it is a binary
|
||||
.map(|bytes| Ok(Cow::Borrowed(binary_as_vector(bytes)?)))
|
||||
.transpose(),
|
||||
ConcreteDataType::String(_) => arg
|
||||
.as_string()
|
||||
.unwrap() // Safe: checked if it is a string
|
||||
.map(|s| Ok(Cow::Owned(parse_f32_vector_from_string(s)?)))
|
||||
.transpose(),
|
||||
ConcreteDataType::Null(_) => Ok(None),
|
||||
_ => InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a u8 slice to a vector value.
|
||||
fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> {
|
||||
if bytes.len() % 4 != 0 {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let num_floats = bytes.len() / 4;
|
||||
let floats: &[f32] = std::slice::from_raw_parts(bytes.as_ptr() as *const f32, num_floats);
|
||||
Ok(floats)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a string to a vector value.
|
||||
/// Valid inputs are strings like "[1.0, 2.0, 3.0]".
|
||||
fn parse_f32_vector_from_string(s: &str) -> Result<Vec<f32>> {
|
||||
let trimmed = s.trim();
|
||||
if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"Failed to parse {s} to Vector value: not properly enclosed in brackets"
|
||||
),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
let content = trimmed[1..trimmed.len() - 1].trim();
|
||||
if content.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
content
|
||||
.split(',')
|
||||
.map(|s| s.trim().parse::<f32>())
|
||||
.collect::<std::result::Result<_, _>>()
|
||||
.map_err(|e| {
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Failed to parse {s} to Vector value: {e}"),
|
||||
}
|
||||
.build()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
fn test_distance_string_string() {
    // Both arguments are string columns; a NULL on either side must give NULL.
    let funcs: [Box<dyn Function>; 3] = [
        Box::new(CosDistanceFunction {}),
        Box::new(L2SqDistanceFunction {}),
        Box::new(DotProductFunction {}),
    ];

    for func in funcs {
        let lhs: VectorRef = Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]));
        let rhs: VectorRef = Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            Some("[0.0, 1.0]"),
            None,
        ]));

        // Check both argument orders.
        for args in [[lhs.clone(), rhs.clone()], [rhs.clone(), lhs.clone()]] {
            let result = func.eval(FunctionContext::default(), &args).unwrap();
            let null_mask: Vec<bool> = (0..4).map(|row| result.get(row).is_null()).collect();
            assert_eq!(null_mask, [false, false, true, true]);
        }
    }
}
|
||||
|
||||
#[test]
fn test_distance_binary_binary() {
    // Both arguments are binary columns; a NULL on either side must give NULL.
    let funcs: [Box<dyn Function>; 3] = [
        Box::new(CosDistanceFunction {}),
        Box::new(L2SqDistanceFunction {}),
        Box::new(DotProductFunction {}),
    ];

    for func in funcs {
        let lhs: VectorRef = Arc::new(BinaryVector::from(vec![
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
            None,
            Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
        ]));
        let rhs: VectorRef = Arc::new(BinaryVector::from(vec![
            // [0.0, 1.0]
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            None,
        ]));

        // Check both argument orders.
        for args in [[lhs.clone(), rhs.clone()], [rhs.clone(), lhs.clone()]] {
            let result = func.eval(FunctionContext::default(), &args).unwrap();
            let null_mask: Vec<bool> = (0..4).map(|row| result.get(row).is_null()).collect();
            assert_eq!(null_mask, [false, false, true, true]);
        }
    }
}
|
||||
|
||||
#[test]
fn test_distance_string_binary() {
    // One string column and one binary column, evaluated in both orders.
    let funcs: [Box<dyn Function>; 3] = [
        Box::new(CosDistanceFunction {}),
        Box::new(L2SqDistanceFunction {}),
        Box::new(DotProductFunction {}),
    ];

    for func in funcs {
        let strings: VectorRef = Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]));
        let binaries: VectorRef = Arc::new(BinaryVector::from(vec![
            // [0.0, 1.0]
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            None,
        ]));

        let orders = [
            [strings.clone(), binaries.clone()],
            [binaries.clone(), strings.clone()],
        ];
        for args in orders {
            let result = func.eval(FunctionContext::default(), &args).unwrap();
            let null_mask: Vec<bool> = (0..4).map(|row| result.get(row).is_null()).collect();
            assert_eq!(null_mask, [false, false, true, true]);
        }
    }
}
|
||||
|
||||
#[test]
fn test_distance_const_string() {
    // A constant string argument paired with string and binary columns.
    let funcs: [Box<dyn Function>; 3] = [
        Box::new(CosDistanceFunction {}),
        Box::new(L2SqDistanceFunction {}),
        Box::new(DotProductFunction {}),
    ];

    for func in funcs {
        let const_str: VectorRef = Arc::new(ConstantVector::new(
            Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
            4,
        ));
        let strings: VectorRef = Arc::new(StringVector::from(vec![
            Some("[0.0, 1.0]"),
            Some("[1.0, 0.0]"),
            None,
            Some("[1.0, 0.0]"),
        ]));
        let binaries: VectorRef = Arc::new(BinaryVector::from(vec![
            // [0.0, 1.0]
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
            None,
        ]));

        // Each case lists the arguments and which result rows must be NULL:
        // only the rows where the non-constant column is NULL.
        let cases: [([VectorRef; 2], [bool; 4]); 4] = [
            (
                [const_str.clone(), strings.clone()],
                [false, false, true, false],
            ),
            (
                [strings.clone(), const_str.clone()],
                [false, false, true, false],
            ),
            (
                [const_str.clone(), binaries.clone()],
                [false, false, false, true],
            ),
            (
                [binaries.clone(), const_str.clone()],
                [false, false, false, true],
            ),
        ];

        for (args, expected_nulls) in cases {
            let result = func.eval(FunctionContext::default(), &args).unwrap();
            let null_mask: Vec<bool> = (0..4).map(|row| result.get(row).is_null()).collect();
            assert_eq!(null_mask, expected_nulls);
        }
    }
}
|
||||
|
||||
#[test]
fn test_invalid_vector_length() {
    // Vectors of different dimensions must be rejected for every metric.
    let funcs: [Box<dyn Function>; 3] = [
        Box::new(CosDistanceFunction {}),
        Box::new(L2SqDistanceFunction {}),
        Box::new(DotProductFunction {}),
    ];

    for func in funcs {
        // String inputs: one element vs two elements.
        let short: VectorRef = Arc::new(StringVector::from(vec!["[1.0]"]));
        let long: VectorRef = Arc::new(StringVector::from(vec!["[1.0, 1.0]"]));
        assert!(func
            .eval(FunctionContext::default(), &[short, long])
            .is_err());

        // Binary inputs: four bytes vs eight bytes.
        let short: VectorRef = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]]));
        let long: VectorRef =
            Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]]));
        assert!(func
            .eval(FunctionContext::default(), &[short, long])
            .is_err());
    }
}
|
||||
|
||||
#[test]
fn test_parse_vector_from_string() {
    // Well-formed literal.
    let parsed = parse_f32_vector_from_string("[1.0, 2.0, 3.0]").unwrap();
    assert_eq!(parsed, vec![1.0, 2.0, 3.0]);

    // Empty brackets give an empty vector.
    let parsed = parse_f32_vector_from_string("[]").unwrap();
    assert!(parsed.is_empty());

    // A non-numeric element is an error.
    assert!(parse_f32_vector_from_string("[1.0, a, 3.0]").is_err());
}
|
||||
|
||||
#[test]
fn test_binary_as_vector() {
    /// Byte buffer with `f32` alignment. A plain `[u8; N]` local only has
    /// alignment 1, so reinterpreting it as `f32`s would rely on where the
    /// compiler happens to place it on the stack.
    #[repr(align(4))]
    struct Aligned<const N: usize>([u8; N]);

    // 1.0_f32 encoded as little-endian bytes.
    let bytes = Aligned([0, 0, 128, 63]);
    let result = binary_as_vector(&bytes.0).unwrap();
    assert_eq!(result, &[1.0]);

    // A length that is not a multiple of four must be rejected.
    let invalid_bytes = Aligned([0, 0, 128]);
    let result = binary_as_vector(&invalid_bytes.0);
    assert!(result.is_err());
}
|
||||
}
|
||||
@@ -1089,7 +1089,7 @@ macro_rules! impl_as_for_value_ref {
|
||||
};
|
||||
}
|
||||
|
||||
impl ValueRef<'_> {
|
||||
impl<'a> ValueRef<'a> {
|
||||
define_data_type_func!(ValueRef);
|
||||
|
||||
/// Returns true if this is null.
|
||||
@@ -1098,12 +1098,12 @@ impl ValueRef<'_> {
|
||||
}
|
||||
|
||||
/// Cast itself to binary slice.
|
||||
pub fn as_binary(&self) -> Result<Option<&[u8]>> {
|
||||
pub fn as_binary(&self) -> Result<Option<&'a [u8]>> {
|
||||
impl_as_for_value_ref!(self, Binary)
|
||||
}
|
||||
|
||||
/// Cast itself to string slice.
|
||||
pub fn as_string(&self) -> Result<Option<&str>> {
|
||||
pub fn as_string(&self) -> Result<Option<&'a str>> {
|
||||
impl_as_for_value_ref!(self, String)
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,186 @@ SELECT * FROM t;
|
||||
| 1970-01-01 00:00:00.003000 | "[7,8,9]" |
|
||||
+----------------------------+-----------+
|
||||
|
||||
SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
|
||||
|
||||
+-----------------------------------------------------------+
|
||||
| round(cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) |
|
||||
+-----------------------------------------------------------+
|
||||
| 1.0 |
|
||||
| 1.0 |
|
||||
| 1.0 |
|
||||
+-----------------------------------------------------------+
|
||||
|
||||
SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-----+
|
||||
| ts | v | d |
|
||||
+-------------------------+--------------------------+-----+
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 1.0 |
|
||||
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 1.0 |
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 1.0 |
|
||||
+-------------------------+--------------------------+-----+
|
||||
|
||||
SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
|
||||
|
||||
+-----------------------------------------------------------+
|
||||
| round(cos_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) |
|
||||
+-----------------------------------------------------------+
|
||||
| 0.0406 |
|
||||
| 0.0018 |
|
||||
| 0.0 |
|
||||
+-----------------------------------------------------------+
|
||||
|
||||
SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+--------+
|
||||
| ts | v | d |
|
||||
+-------------------------+--------------------------+--------+
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 |
|
||||
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0018 |
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0406 |
|
||||
+-------------------------+--------------------------+--------+
|
||||
|
||||
SELECT round(cos_distance(v, v), 4) FROM t;
|
||||
|
||||
+---------------------------------------+
|
||||
| round(cos_distance(t.v,t.v),Int64(4)) |
|
||||
+---------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+---------------------------------------+
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT cos_distance(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
|
||||
-- Invalid type --
|
||||
SELECT cos_distance(v, 1.0) FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
|
||||
|
||||
SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
|
||||
|
||||
+------------------------------------------------------------+
|
||||
| round(l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) |
|
||||
+------------------------------------------------------------+
|
||||
| 14.0 |
|
||||
| 77.0 |
|
||||
| 194.0 |
|
||||
+------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-------+
|
||||
| ts | v | d |
|
||||
+-------------------------+--------------------------+-------+
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 14.0 |
|
||||
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 77.0 |
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 |
|
||||
+-------------------------+--------------------------+-------+
|
||||
|
||||
SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
|
||||
|
||||
+------------------------------------------------------------+
|
||||
| round(l2sq_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) |
|
||||
+------------------------------------------------------------+
|
||||
| 108.0 |
|
||||
| 27.0 |
|
||||
| 0.0 |
|
||||
+------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-------+
|
||||
| ts | v | d |
|
||||
+-------------------------+--------------------------+-------+
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 |
|
||||
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 27.0 |
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 108.0 |
|
||||
+-------------------------+--------------------------+-------+
|
||||
|
||||
SELECT round(l2sq_distance(v, v), 4) FROM t;
|
||||
|
||||
+----------------------------------------+
|
||||
| round(l2sq_distance(t.v,t.v),Int64(4)) |
|
||||
+----------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+----------------------------------------+
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT l2sq_distance(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
|
||||
-- Invalid type --
|
||||
SELECT l2sq_distance(v, 1.0) FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
|
||||
|
||||
SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
|
||||
|
||||
+----------------------------------------------------------+
|
||||
| round(dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) |
|
||||
+----------------------------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+----------------------------------------------------------+
|
||||
|
||||
SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-----+
|
||||
| ts | v | d |
|
||||
+-------------------------+--------------------------+-----+
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0 |
|
||||
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0 |
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 |
|
||||
+-------------------------+--------------------------+-----+
|
||||
|
||||
SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t;
|
||||
|
||||
+----------------------------------------------------------+
|
||||
| round(dot_product(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) |
|
||||
+----------------------------------------------------------+
|
||||
| 50.0 |
|
||||
| 122.0 |
|
||||
| 194.0 |
|
||||
+----------------------------------------------------------+
|
||||
|
||||
SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-------+
|
||||
| ts | v | d |
|
||||
+-------------------------+--------------------------+-------+
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 50.0 |
|
||||
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 122.0 |
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 |
|
||||
+-------------------------+--------------------------+-------+
|
||||
|
||||
SELECT round(dot_product(v, v), 4) FROM t;
|
||||
|
||||
+--------------------------------------+
|
||||
| round(dot_product(t.v,t.v),Int64(4)) |
|
||||
+--------------------------------------+
|
||||
| 14.0 |
|
||||
| 77.0 |
|
||||
| 194.0 |
|
||||
+--------------------------------------+
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT dot_product(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
|
||||
-- Invalid type --
|
||||
SELECT dot_product(v, 1.0) FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
|
||||
|
||||
-- Unexpected dimension --
|
||||
INSERT INTO t VALUES
|
||||
(4, "[1.0]");
|
||||
|
||||
@@ -11,6 +11,55 @@ SELECT * FROM t;
|
||||
-- SQLNESS PROTOCOL POSTGRES
|
||||
SELECT * FROM t;
|
||||
|
||||
SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
|
||||
|
||||
SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
|
||||
|
||||
SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(cos_distance(v, v), 4) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT cos_distance(v, '[1.0]') FROM t;
|
||||
|
||||
-- Invalid type --
|
||||
SELECT cos_distance(v, 1.0) FROM t;
|
||||
|
||||
SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
|
||||
|
||||
SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
|
||||
|
||||
SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(l2sq_distance(v, v), 4) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT l2sq_distance(v, '[1.0]') FROM t;
|
||||
|
||||
-- Invalid type --
|
||||
SELECT l2sq_distance(v, 1.0) FROM t;
|
||||
|
||||
|
||||
SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
|
||||
|
||||
SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t;
|
||||
|
||||
SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(dot_product(v, v), 4) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT dot_product(v, '[1.0]') FROM t;
|
||||
|
||||
-- Invalid type --
|
||||
SELECT dot_product(v, 1.0) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
INSERT INTO t VALUES
|
||||
(4, "[1.0]");
|
||||
|
||||
Reference in New Issue
Block a user