mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-20 15:00:40 +00:00
feat(vector): remove simsimd and use nalgebra instead (#5027)
* feat(vector): remove `simsimd` and use `nalgebra` instead Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> * keep thing simple Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> --------- Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
This commit is contained in:
@@ -33,6 +33,7 @@ geo-types = { version = "0.7", optional = true }
|
||||
geohash = { version = "0.13", optional = true }
|
||||
h3o = { version = "0.6", optional = true }
|
||||
jsonb.workspace = true
|
||||
nalgebra = "0.33"
|
||||
num = "0.4"
|
||||
num-traits = "0.2"
|
||||
once_cell.workspace = true
|
||||
@@ -41,7 +42,6 @@ s2 = { version = "0.0.12", optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
session.workspace = true
|
||||
simsimd = "4"
|
||||
snafu.workspace = true
|
||||
sql.workspace = true
|
||||
statrs = "0.16"
|
||||
@@ -50,6 +50,7 @@ table.workspace = true
|
||||
wkt = { version = "0.11", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
ron = "0.7"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio.workspace = true
|
||||
|
||||
@@ -12,6 +12,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod cos;
|
||||
mod dot;
|
||||
mod l2sq;
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
@@ -21,14 +25,14 @@ use common_query::prelude::Signature;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::value::ValueRef;
|
||||
use datatypes::vectors::{Float64VectorBuilder, MutableVector, Vector, VectorRef};
|
||||
use datatypes::vectors::{Float32VectorBuilder, MutableVector, Vector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::helper;
|
||||
|
||||
macro_rules! define_distance_function {
|
||||
($StructName:ident, $display_name:expr, $similarity_method:ident) => {
|
||||
($StructName:ident, $display_name:expr, $similarity_method:path) => {
|
||||
|
||||
/// A function calculates the distance between two vectors.
|
||||
|
||||
@@ -41,7 +45,7 @@ macro_rules! define_distance_function {
|
||||
}
|
||||
|
||||
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
|
||||
Ok(ConcreteDataType::float64_datatype())
|
||||
Ok(ConcreteDataType::float32_datatype())
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -71,7 +75,7 @@ macro_rules! define_distance_function {
|
||||
let arg1 = &columns[1];
|
||||
|
||||
let size = arg0.len();
|
||||
let mut result = Float64VectorBuilder::with_capacity(size);
|
||||
let mut result = Float32VectorBuilder::with_capacity(size);
|
||||
if size == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
@@ -101,9 +105,8 @@ macro_rules! define_distance_function {
|
||||
}
|
||||
);
|
||||
|
||||
let f = <f32 as simsimd::SpatialSimilarity>::$similarity_method;
|
||||
// Safe: checked if the length of the vectors match
|
||||
let d = f(vec0.as_ref(), vec1.as_ref()).unwrap();
|
||||
// Checked if the length of the vectors match
|
||||
let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
|
||||
result.push(Some(d));
|
||||
} else {
|
||||
result.push_null();
|
||||
@@ -122,9 +125,9 @@ macro_rules! define_distance_function {
|
||||
}
|
||||
}
|
||||
|
||||
define_distance_function!(CosDistanceFunction, "cos_distance", cos);
|
||||
define_distance_function!(L2SqDistanceFunction, "l2sq_distance", l2sq);
|
||||
define_distance_function!(DotProductFunction, "dot_product", dot);
|
||||
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
|
||||
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
|
||||
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
|
||||
|
||||
/// Parse a vector value if the value is a constant string.
|
||||
fn parse_if_constant_string(arg: &Arc<dyn Vector>) -> Result<Option<Vec<f32>>> {
|
||||
@@ -148,7 +151,7 @@ fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
ConcreteDataType::Binary(_) => arg
|
||||
.as_binary()
|
||||
.unwrap() // Safe: checked if it is a binary
|
||||
.map(|bytes| Ok(Cow::Borrowed(binary_as_vector(bytes)?)))
|
||||
.map(binary_as_vector)
|
||||
.transpose(),
|
||||
ConcreteDataType::String(_) => arg
|
||||
.as_string()
|
||||
@@ -164,18 +167,28 @@ fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
}
|
||||
|
||||
/// Convert a u8 slice to a vector value.
|
||||
fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> {
|
||||
if bytes.len() % 4 != 0 {
|
||||
fn binary_as_vector(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
|
||||
if bytes.len() % std::mem::size_of::<f32>() != 0 {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let num_floats = bytes.len() / 4;
|
||||
let floats: &[f32] = std::slice::from_raw_parts(bytes.as_ptr() as *const f32, num_floats);
|
||||
Ok(floats)
|
||||
if cfg!(target_endian = "little") {
|
||||
Ok(unsafe {
|
||||
let vec = std::slice::from_raw_parts(
|
||||
bytes.as_ptr() as *const f32,
|
||||
bytes.len() / std::mem::size_of::<f32>(),
|
||||
);
|
||||
Cow::Borrowed(vec)
|
||||
})
|
||||
} else {
|
||||
let v = bytes
|
||||
.chunks_exact(std::mem::size_of::<f32>())
|
||||
.map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
|
||||
.collect::<Vec<f32>>();
|
||||
Ok(Cow::Owned(v))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,7 +473,7 @@ mod tests {
|
||||
fn test_binary_as_vector() {
|
||||
let bytes = [0, 0, 128, 63];
|
||||
let result = binary_as_vector(&bytes).unwrap();
|
||||
assert_eq!(result, &[1.0]);
|
||||
assert_eq!(result.as_ref(), &[1.0]);
|
||||
|
||||
let invalid_bytes = [0, 0, 128];
|
||||
let result = binary_as_vector(&invalid_bytes);
|
||||
|
||||
87
src/common/function/src/scalars/vector/distance/cos.rs
Normal file
87
src/common/function/src/scalars/vector/distance/cos.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use nalgebra::DVectorView;
|
||||
|
||||
/// Calculates the cos distance between two vectors.
|
||||
///
|
||||
/// **Note:** Must ensure that the length of the two vectors are the same.
|
||||
pub fn cos(lhs: &[f32], rhs: &[f32]) -> f32 {
|
||||
let lhs_vec = DVectorView::from_slice(lhs, lhs.len());
|
||||
let rhs_vec = DVectorView::from_slice(rhs, rhs.len());
|
||||
|
||||
let dot_product = lhs_vec.dot(&rhs_vec);
|
||||
let lhs_norm = lhs_vec.norm();
|
||||
let rhs_norm = rhs_vec.norm();
|
||||
if dot_product.abs() < f32::EPSILON
|
||||
|| lhs_norm.abs() < f32::EPSILON
|
||||
|| rhs_norm.abs() < f32::EPSILON
|
||||
{
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
let cos_similar = dot_product / (lhs_norm * rhs_norm);
|
||||
let res = 1.0 - cos_similar;
|
||||
if res.abs() < f32::EPSILON {
|
||||
0.0
|
||||
} else {
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Checks `cos` against hand-computed distances for a fixed set of
    /// three-dimensional vector pairs, including all-zero inputs.
    #[test]
    fn test_cos_scalar() {
        // (lhs, rhs, expected cosine distance)
        let cases: [([f32; 3], [f32; 3], f32); 9] = [
            ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 0.0),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 0.025),
            ([1.0, 2.0, 3.0], [7.0, 8.0, 9.0], 0.04),
            ([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 1.0),
            ([0.0, 0.0, 0.0], [4.0, 5.0, 6.0], 1.0),
            ([0.0, 0.0, 0.0], [7.0, 8.0, 9.0], 1.0),
            ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], 0.04),
            ([7.0, 8.0, 9.0], [4.0, 5.0, 6.0], 0.0),
            ([7.0, 8.0, 9.0], [7.0, 8.0, 9.0], 0.0),
        ];

        for (lhs, rhs, expected) in &cases {
            assert_relative_eq!(cos(lhs, rhs), *expected, epsilon = 1e-2);
        }
    }
}
|
||||
71
src/common/function/src/scalars/vector/distance/dot.rs
Normal file
71
src/common/function/src/scalars/vector/distance/dot.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use nalgebra::DVectorView;
|
||||
|
||||
/// Calculates the dot product between two vectors.
|
||||
///
|
||||
/// **Note:** Must ensure that the length of the two vectors are the same.
|
||||
pub fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
|
||||
let lhs = DVectorView::from_slice(lhs, lhs.len());
|
||||
let rhs = DVectorView::from_slice(rhs, rhs.len());
|
||||
|
||||
lhs.dot(&rhs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Checks `dot` against hand-computed products for a fixed set of
    /// three-dimensional vector pairs, including all-zero inputs.
    #[test]
    fn test_dot_scalar() {
        // (lhs, rhs, expected dot product)
        let cases: [([f32; 3], [f32; 3], f32); 9] = [
            ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 14.0),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 32.0),
            ([1.0, 2.0, 3.0], [7.0, 8.0, 9.0], 50.0),
            ([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 0.0),
            ([0.0, 0.0, 0.0], [4.0, 5.0, 6.0], 0.0),
            ([0.0, 0.0, 0.0], [7.0, 8.0, 9.0], 0.0),
            ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], 50.0),
            ([7.0, 8.0, 9.0], [4.0, 5.0, 6.0], 122.0),
            ([7.0, 8.0, 9.0], [7.0, 8.0, 9.0], 194.0),
        ];

        for (lhs, rhs, expected) in &cases {
            assert_relative_eq!(dot(lhs, rhs), *expected, epsilon = 1e-2);
        }
    }
}
|
||||
71
src/common/function/src/scalars/vector/distance/l2sq.rs
Normal file
71
src/common/function/src/scalars/vector/distance/l2sq.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use nalgebra::DVectorView;
|
||||
|
||||
/// Calculates the squared L2 distance between two vectors.
|
||||
///
|
||||
/// **Note:** Must ensure that the length of the two vectors are the same.
|
||||
pub fn l2sq(lhs: &[f32], rhs: &[f32]) -> f32 {
|
||||
let lhs = DVectorView::from_slice(lhs, lhs.len());
|
||||
let rhs = DVectorView::from_slice(rhs, rhs.len());
|
||||
|
||||
(lhs - rhs).norm_squared()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    /// Checks `l2sq` against hand-computed squared distances for a fixed set
    /// of three-dimensional vector pairs, including all-zero inputs.
    #[test]
    fn test_l2sq_scalar() {
        // (lhs, rhs, expected squared L2 distance)
        let cases: [([f32; 3], [f32; 3], f32); 9] = [
            ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 0.0),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 27.0),
            ([1.0, 2.0, 3.0], [7.0, 8.0, 9.0], 108.0),
            ([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 14.0),
            ([0.0, 0.0, 0.0], [4.0, 5.0, 6.0], 77.0),
            ([0.0, 0.0, 0.0], [7.0, 8.0, 9.0], 194.0),
            ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], 108.0),
            ([7.0, 8.0, 9.0], [4.0, 5.0, 6.0], 27.0),
            ([7.0, 8.0, 9.0], [7.0, 8.0, 9.0], 0.0),
        ];

        for (lhs, rhs, expected) in &cases {
            assert_relative_eq!(l2sq(lhs, rhs), *expected, epsilon = 1e-2);
        }
    }
}
|
||||
Reference in New Issue
Block a user