mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-18 22:10:42 +00:00
feat(vector): remove simsimd and use nalgebra instead (#5027)
* feat(vector): remove `simsimd` and use `nalgebra` instead Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> * keep thing simple Signed-off-by: Zhenchi <zhongzc_arch@outlook.com> --------- Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
This commit is contained in:
58
Cargo.lock
generated
58
Cargo.lock
generated
@@ -2060,6 +2060,7 @@ name = "common-function"
|
||||
version = "0.9.5"
|
||||
dependencies = [
|
||||
"api",
|
||||
"approx 0.5.1",
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
"common-base",
|
||||
@@ -2080,6 +2081,7 @@ dependencies = [
|
||||
"geohash",
|
||||
"h3o",
|
||||
"jsonb",
|
||||
"nalgebra 0.33.2",
|
||||
"num",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
@@ -2089,7 +2091,6 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"session",
|
||||
"simsimd",
|
||||
"snafu 0.8.5",
|
||||
"sql",
|
||||
"statrs",
|
||||
@@ -7082,13 +7083,29 @@ checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff"
|
||||
dependencies = [
|
||||
"approx 0.5.1",
|
||||
"matrixmultiply",
|
||||
"nalgebra-macros",
|
||||
"nalgebra-macros 0.1.0",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
"simba",
|
||||
"simba 0.6.0",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.33.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b"
|
||||
dependencies = [
|
||||
"approx 0.5.1",
|
||||
"matrixmultiply",
|
||||
"nalgebra-macros 0.2.2",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"simba 0.9.0",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
@@ -7103,6 +7120,17 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra-macros"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "named_pipe"
|
||||
version = "0.4.1"
|
||||
@@ -11177,6 +11205,19 @@ dependencies = [
|
||||
"wide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3a386a501cd104797982c15ae17aafe8b9261315b5d07e3ec803f2ea26be0fa"
|
||||
dependencies = [
|
||||
"approx 0.5.1",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"wide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simdutf8"
|
||||
version = "0.1.5"
|
||||
@@ -11215,15 +11256,6 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simsimd"
|
||||
version = "4.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "siphasher"
|
||||
version = "0.3.11"
|
||||
@@ -11664,7 +11696,7 @@ checksum = "b35a062dbadac17a42e0fc64c27f419b25d6fae98572eb43c8814c9e873d7721"
|
||||
dependencies = [
|
||||
"approx 0.5.1",
|
||||
"lazy_static",
|
||||
"nalgebra",
|
||||
"nalgebra 0.29.0",
|
||||
"num-traits",
|
||||
"rand",
|
||||
]
|
||||
|
||||
@@ -33,6 +33,7 @@ geo-types = { version = "0.7", optional = true }
|
||||
geohash = { version = "0.13", optional = true }
|
||||
h3o = { version = "0.6", optional = true }
|
||||
jsonb.workspace = true
|
||||
nalgebra = "0.33"
|
||||
num = "0.4"
|
||||
num-traits = "0.2"
|
||||
once_cell.workspace = true
|
||||
@@ -41,7 +42,6 @@ s2 = { version = "0.0.12", optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
session.workspace = true
|
||||
simsimd = "4"
|
||||
snafu.workspace = true
|
||||
sql.workspace = true
|
||||
statrs = "0.16"
|
||||
@@ -50,6 +50,7 @@ table.workspace = true
|
||||
wkt = { version = "0.11", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
ron = "0.7"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio.workspace = true
|
||||
|
||||
@@ -12,6 +12,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod cos;
|
||||
mod dot;
|
||||
mod l2sq;
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
@@ -21,14 +25,14 @@ use common_query::prelude::Signature;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::value::ValueRef;
|
||||
use datatypes::vectors::{Float64VectorBuilder, MutableVector, Vector, VectorRef};
|
||||
use datatypes::vectors::{Float32VectorBuilder, MutableVector, Vector, VectorRef};
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::helper;
|
||||
|
||||
macro_rules! define_distance_function {
|
||||
($StructName:ident, $display_name:expr, $similarity_method:ident) => {
|
||||
($StructName:ident, $display_name:expr, $similarity_method:path) => {
|
||||
|
||||
/// A function that calculates the distance between two vectors.
|
||||
|
||||
@@ -41,7 +45,7 @@ macro_rules! define_distance_function {
|
||||
}
|
||||
|
||||
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
|
||||
Ok(ConcreteDataType::float64_datatype())
|
||||
Ok(ConcreteDataType::float32_datatype())
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
@@ -71,7 +75,7 @@ macro_rules! define_distance_function {
|
||||
let arg1 = &columns[1];
|
||||
|
||||
let size = arg0.len();
|
||||
let mut result = Float64VectorBuilder::with_capacity(size);
|
||||
let mut result = Float32VectorBuilder::with_capacity(size);
|
||||
if size == 0 {
|
||||
return Ok(result.to_vector());
|
||||
}
|
||||
@@ -101,9 +105,8 @@ macro_rules! define_distance_function {
|
||||
}
|
||||
);
|
||||
|
||||
let f = <f32 as simsimd::SpatialSimilarity>::$similarity_method;
|
||||
// Safe: checked if the length of the vectors match
|
||||
let d = f(vec0.as_ref(), vec1.as_ref()).unwrap();
|
||||
// Checked if the length of the vectors match
|
||||
let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
|
||||
result.push(Some(d));
|
||||
} else {
|
||||
result.push_null();
|
||||
@@ -122,9 +125,9 @@ macro_rules! define_distance_function {
|
||||
}
|
||||
}
|
||||
|
||||
define_distance_function!(CosDistanceFunction, "cos_distance", cos);
|
||||
define_distance_function!(L2SqDistanceFunction, "l2sq_distance", l2sq);
|
||||
define_distance_function!(DotProductFunction, "dot_product", dot);
|
||||
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
|
||||
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
|
||||
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
|
||||
|
||||
/// Parse a vector value if the value is a constant string.
|
||||
fn parse_if_constant_string(arg: &Arc<dyn Vector>) -> Result<Option<Vec<f32>>> {
|
||||
@@ -148,7 +151,7 @@ fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
ConcreteDataType::Binary(_) => arg
|
||||
.as_binary()
|
||||
.unwrap() // Safe: checked if it is a binary
|
||||
.map(|bytes| Ok(Cow::Borrowed(binary_as_vector(bytes)?)))
|
||||
.map(binary_as_vector)
|
||||
.transpose(),
|
||||
ConcreteDataType::String(_) => arg
|
||||
.as_string()
|
||||
@@ -164,18 +167,28 @@ fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
|
||||
}
|
||||
|
||||
/// Convert a u8 slice to a vector value.
|
||||
fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> {
|
||||
if bytes.len() % 4 != 0 {
|
||||
fn binary_as_vector(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
|
||||
if bytes.len() % std::mem::size_of::<f32>() != 0 {
|
||||
return InvalidFuncArgsSnafu {
|
||||
err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let num_floats = bytes.len() / 4;
|
||||
let floats: &[f32] = std::slice::from_raw_parts(bytes.as_ptr() as *const f32, num_floats);
|
||||
Ok(floats)
|
||||
if cfg!(target_endian = "little") {
|
||||
Ok(unsafe {
|
||||
let vec = std::slice::from_raw_parts(
|
||||
bytes.as_ptr() as *const f32,
|
||||
bytes.len() / std::mem::size_of::<f32>(),
|
||||
);
|
||||
Cow::Borrowed(vec)
|
||||
})
|
||||
} else {
|
||||
let v = bytes
|
||||
.chunks_exact(std::mem::size_of::<f32>())
|
||||
.map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
|
||||
.collect::<Vec<f32>>();
|
||||
Ok(Cow::Owned(v))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,7 +473,7 @@ mod tests {
|
||||
fn test_binary_as_vector() {
|
||||
let bytes = [0, 0, 128, 63];
|
||||
let result = binary_as_vector(&bytes).unwrap();
|
||||
assert_eq!(result, &[1.0]);
|
||||
assert_eq!(result.as_ref(), &[1.0]);
|
||||
|
||||
let invalid_bytes = [0, 0, 128];
|
||||
let result = binary_as_vector(&invalid_bytes);
|
||||
|
||||
87
src/common/function/src/scalars/vector/distance/cos.rs
Normal file
87
src/common/function/src/scalars/vector/distance/cos.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use nalgebra::DVectorView;
|
||||
|
||||
/// Calculates the cos distance between two vectors.
|
||||
///
|
||||
/// **Note:** Must ensure that the length of the two vectors are the same.
|
||||
pub fn cos(lhs: &[f32], rhs: &[f32]) -> f32 {
|
||||
let lhs_vec = DVectorView::from_slice(lhs, lhs.len());
|
||||
let rhs_vec = DVectorView::from_slice(rhs, rhs.len());
|
||||
|
||||
let dot_product = lhs_vec.dot(&rhs_vec);
|
||||
let lhs_norm = lhs_vec.norm();
|
||||
let rhs_norm = rhs_vec.norm();
|
||||
if dot_product.abs() < f32::EPSILON
|
||||
|| lhs_norm.abs() < f32::EPSILON
|
||||
|| rhs_norm.abs() < f32::EPSILON
|
||||
{
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
let cos_similar = dot_product / (lhs_norm * rhs_norm);
|
||||
let res = 1.0 - cos_similar;
|
||||
if res.abs() < f32::EPSILON {
|
||||
0.0
|
||||
} else {
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_cos_scalar() {
        // Each entry is (lhs, rhs, expected cos distance), compared with `epsilon = 1e-2`.
        let cases: [([f32; 3], [f32; 3], f32); 9] = [
            ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 0.0),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 0.025),
            ([1.0, 2.0, 3.0], [7.0, 8.0, 9.0], 0.04),
            ([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 1.0),
            ([0.0, 0.0, 0.0], [4.0, 5.0, 6.0], 1.0),
            ([0.0, 0.0, 0.0], [7.0, 8.0, 9.0], 1.0),
            ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], 0.04),
            ([7.0, 8.0, 9.0], [4.0, 5.0, 6.0], 0.0),
            ([7.0, 8.0, 9.0], [7.0, 8.0, 9.0], 0.0),
        ];

        for (lhs, rhs, expected) in &cases {
            assert_relative_eq!(cos(lhs, rhs), *expected, epsilon = 1e-2);
        }
    }
}
|
||||
71
src/common/function/src/scalars/vector/distance/dot.rs
Normal file
71
src/common/function/src/scalars/vector/distance/dot.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use nalgebra::DVectorView;
|
||||
|
||||
/// Calculates the dot product between two vectors.
|
||||
///
|
||||
/// **Note:** Must ensure that the length of the two vectors are the same.
|
||||
pub fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
|
||||
let lhs = DVectorView::from_slice(lhs, lhs.len());
|
||||
let rhs = DVectorView::from_slice(rhs, rhs.len());
|
||||
|
||||
lhs.dot(&rhs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_dot_scalar() {
        // Each entry is (lhs, rhs, expected dot product), compared with `epsilon = 1e-2`.
        let cases: [([f32; 3], [f32; 3], f32); 9] = [
            ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 14.0),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 32.0),
            ([1.0, 2.0, 3.0], [7.0, 8.0, 9.0], 50.0),
            ([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 0.0),
            ([0.0, 0.0, 0.0], [4.0, 5.0, 6.0], 0.0),
            ([0.0, 0.0, 0.0], [7.0, 8.0, 9.0], 0.0),
            ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], 50.0),
            ([7.0, 8.0, 9.0], [4.0, 5.0, 6.0], 122.0),
            ([7.0, 8.0, 9.0], [7.0, 8.0, 9.0], 194.0),
        ];

        for (lhs, rhs, expected) in &cases {
            assert_relative_eq!(dot(lhs, rhs), *expected, epsilon = 1e-2);
        }
    }
}
|
||||
71
src/common/function/src/scalars/vector/distance/l2sq.rs
Normal file
71
src/common/function/src/scalars/vector/distance/l2sq.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use nalgebra::DVectorView;
|
||||
|
||||
/// Calculates the squared L2 distance between two vectors.
|
||||
///
|
||||
/// **Note:** Must ensure that the length of the two vectors are the same.
|
||||
pub fn l2sq(lhs: &[f32], rhs: &[f32]) -> f32 {
|
||||
let lhs = DVectorView::from_slice(lhs, lhs.len());
|
||||
let rhs = DVectorView::from_slice(rhs, rhs.len());
|
||||
|
||||
(lhs - rhs).norm_squared()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_l2sq_scalar() {
        // Each entry is (lhs, rhs, expected squared L2 distance), compared with `epsilon = 1e-2`.
        let cases: [([f32; 3], [f32; 3], f32); 9] = [
            ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 0.0),
            ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 27.0),
            ([1.0, 2.0, 3.0], [7.0, 8.0, 9.0], 108.0),
            ([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 14.0),
            ([0.0, 0.0, 0.0], [4.0, 5.0, 6.0], 77.0),
            ([0.0, 0.0, 0.0], [7.0, 8.0, 9.0], 194.0),
            ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], 108.0),
            ([7.0, 8.0, 9.0], [4.0, 5.0, 6.0], 27.0),
            ([7.0, 8.0, 9.0], [7.0, 8.0, 9.0], 0.0),
        ];

        for (lhs, rhs, expected) in &cases {
            assert_relative_eq!(l2sq(lhs, rhs), *expected, epsilon = 1e-2);
        }
    }
}
|
||||
@@ -85,15 +85,12 @@ pub fn vector_type_value_to_string(val: &[u8], dim: u32) -> Result<String> {
|
||||
return Ok("[]".to_string());
|
||||
}
|
||||
|
||||
let elements = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
val.as_ptr() as *const f32,
|
||||
val.len() / std::mem::size_of::<f32>(),
|
||||
)
|
||||
};
|
||||
let elements = val
|
||||
.chunks_exact(std::mem::size_of::<f32>())
|
||||
.map(|e| f32::from_le_bytes(e.try_into().unwrap()));
|
||||
|
||||
let mut s = String::from("[");
|
||||
for (i, e) in elements.iter().enumerate() {
|
||||
for (i, e) in elements.enumerate() {
|
||||
if i > 0 {
|
||||
s.push(',');
|
||||
}
|
||||
@@ -150,12 +147,19 @@ pub fn parse_string_to_vector_type_value(s: &str, dim: u32) -> Result<Vec<u8>> {
|
||||
}
|
||||
|
||||
// Convert Vec<f32> to Vec<u8>
|
||||
let bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
elements.as_ptr() as *const u8,
|
||||
elements.len() * std::mem::size_of::<f32>(),
|
||||
)
|
||||
.to_vec()
|
||||
let bytes = if cfg!(target_endian = "little") {
|
||||
unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
elements.as_ptr() as *const u8,
|
||||
elements.len() * std::mem::size_of::<f32>(),
|
||||
)
|
||||
.to_vec()
|
||||
}
|
||||
} else {
|
||||
elements
|
||||
.iter()
|
||||
.flat_map(|e| e.to_le_bytes())
|
||||
.collect::<Vec<u8>>()
|
||||
};
|
||||
|
||||
Ok(bytes)
|
||||
|
||||
@@ -31,17 +31,17 @@ SELECT * FROM t;
|
||||
| 1970-01-01 00:00:00.003000 | "[7,8,9]" |
|
||||
+----------------------------+-----------+
|
||||
|
||||
SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
SELECT round(vec_cos_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
|
||||
+-----------------------------------------------------------+
|
||||
| round(cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(2)) |
|
||||
+-----------------------------------------------------------+
|
||||
| 1.0 |
|
||||
| 1.0 |
|
||||
| 1.0 |
|
||||
+-----------------------------------------------------------+
|
||||
+---------------------------------------------------------------+
|
||||
| round(vec_cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(2)) |
|
||||
+---------------------------------------------------------------+
|
||||
| 1.0 |
|
||||
| 1.0 |
|
||||
| 1.0 |
|
||||
+---------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_cos_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-----+
|
||||
| ts | v | d |
|
||||
@@ -51,17 +51,17 @@ SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 1.0 |
|
||||
+-------------------------+--------------------------+-----+
|
||||
|
||||
SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
SELECT round(vec_cos_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
|
||||
+-----------------------------------------------------------+
|
||||
| round(cos_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(2)) |
|
||||
+-----------------------------------------------------------+
|
||||
| 0.04 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+-----------------------------------------------------------+
|
||||
+---------------------------------------------------------------+
|
||||
| round(vec_cos_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(2)) |
|
||||
+---------------------------------------------------------------+
|
||||
| 0.04 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+---------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_cos_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+------+
|
||||
| ts | v | d |
|
||||
@@ -71,37 +71,37 @@ SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.04 |
|
||||
+-------------------------+--------------------------+------+
|
||||
|
||||
SELECT round(cos_distance(v, v), 2) FROM t;
|
||||
SELECT round(vec_cos_distance(v, v), 2) FROM t;
|
||||
|
||||
+---------------------------------------+
|
||||
| round(cos_distance(t.v,t.v),Int64(2)) |
|
||||
+---------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+---------------------------------------+
|
||||
+-------------------------------------------+
|
||||
| round(vec_cos_distance(t.v,t.v),Int64(2)) |
|
||||
+-------------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+-------------------------------------------+
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT cos_distance(v, '[1.0]') FROM t;
|
||||
SELECT vec_cos_distance(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
|
||||
-- Invalid type --
|
||||
SELECT cos_distance(v, 1.0) FROM t;
|
||||
SELECT vec_cos_distance(v, 1.0) FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
|
||||
|
||||
SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
SELECT round(vec_l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
|
||||
+------------------------------------------------------------+
|
||||
| round(l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(2)) |
|
||||
+------------------------------------------------------------+
|
||||
| 14.0 |
|
||||
| 77.0 |
|
||||
| 194.0 |
|
||||
+------------------------------------------------------------+
|
||||
+----------------------------------------------------------------+
|
||||
| round(vec_l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(2)) |
|
||||
+----------------------------------------------------------------+
|
||||
| 14.0 |
|
||||
| 77.0 |
|
||||
| 194.0 |
|
||||
+----------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-------+
|
||||
| ts | v | d |
|
||||
@@ -111,17 +111,17 @@ SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 |
|
||||
+-------------------------+--------------------------+-------+
|
||||
|
||||
SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
SELECT round(vec_l2sq_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
|
||||
+------------------------------------------------------------+
|
||||
| round(l2sq_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(2)) |
|
||||
+------------------------------------------------------------+
|
||||
| 108.0 |
|
||||
| 27.0 |
|
||||
| 0.0 |
|
||||
+------------------------------------------------------------+
|
||||
+----------------------------------------------------------------+
|
||||
| round(vec_l2sq_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(2)) |
|
||||
+----------------------------------------------------------------+
|
||||
| 108.0 |
|
||||
| 27.0 |
|
||||
| 0.0 |
|
||||
+----------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_l2sq_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-------+
|
||||
| ts | v | d |
|
||||
@@ -131,37 +131,37 @@ SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 108.0 |
|
||||
+-------------------------+--------------------------+-------+
|
||||
|
||||
SELECT round(l2sq_distance(v, v), 2) FROM t;
|
||||
SELECT round(vec_l2sq_distance(v, v), 2) FROM t;
|
||||
|
||||
+----------------------------------------+
|
||||
| round(l2sq_distance(t.v,t.v),Int64(2)) |
|
||||
+----------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+----------------------------------------+
|
||||
+--------------------------------------------+
|
||||
| round(vec_l2sq_distance(t.v,t.v),Int64(2)) |
|
||||
+--------------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+--------------------------------------------+
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT l2sq_distance(v, '[1.0]') FROM t;
|
||||
SELECT vec_l2sq_distance(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
|
||||
-- Invalid type --
|
||||
SELECT l2sq_distance(v, 1.0) FROM t;
|
||||
SELECT vec_l2sq_distance(v, 1.0) FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
|
||||
|
||||
SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
SELECT round(vec_dot_product(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
|
||||
+----------------------------------------------------------+
|
||||
| round(dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(2)) |
|
||||
+----------------------------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+----------------------------------------------------------+
|
||||
+--------------------------------------------------------------+
|
||||
| round(vec_dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(2)) |
|
||||
+--------------------------------------------------------------+
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
| 0.0 |
|
||||
+--------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_dot_product(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-----+
|
||||
| ts | v | d |
|
||||
@@ -171,17 +171,17 @@ SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 |
|
||||
+-------------------------+--------------------------+-----+
|
||||
|
||||
SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
SELECT round(vec_dot_product('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
|
||||
+----------------------------------------------------------+
|
||||
| round(dot_product(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(2)) |
|
||||
+----------------------------------------------------------+
|
||||
| 50.0 |
|
||||
| 122.0 |
|
||||
| 194.0 |
|
||||
+----------------------------------------------------------+
|
||||
+--------------------------------------------------------------+
|
||||
| round(vec_dot_product(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(2)) |
|
||||
+--------------------------------------------------------------+
|
||||
| 50.0 |
|
||||
| 122.0 |
|
||||
| 194.0 |
|
||||
+--------------------------------------------------------------+
|
||||
|
||||
SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_dot_product('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
|
||||
+-------------------------+--------------------------+-------+
|
||||
| ts | v | d |
|
||||
@@ -191,23 +191,23 @@ SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 |
|
||||
+-------------------------+--------------------------+-------+
|
||||
|
||||
SELECT round(dot_product(v, v), 2) FROM t;
|
||||
SELECT round(vec_dot_product(v, v), 2) FROM t;
|
||||
|
||||
+--------------------------------------+
|
||||
| round(dot_product(t.v,t.v),Int64(2)) |
|
||||
+--------------------------------------+
|
||||
| 14.0 |
|
||||
| 77.0 |
|
||||
| 194.0 |
|
||||
+--------------------------------------+
|
||||
+------------------------------------------+
|
||||
| round(vec_dot_product(t.v,t.v),Int64(2)) |
|
||||
+------------------------------------------+
|
||||
| 14.0 |
|
||||
| 77.0 |
|
||||
| 194.0 |
|
||||
+------------------------------------------+
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT dot_product(v, '[1.0]') FROM t;
|
||||
SELECT vec_dot_product(v, '[1.0]') FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
|
||||
|
||||
-- Invalid type --
|
||||
SELECT dot_product(v, 1.0) FROM t;
|
||||
SELECT vec_dot_product(v, 1.0) FROM t;
|
||||
|
||||
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
|
||||
|
||||
|
||||
@@ -11,54 +11,54 @@ SELECT * FROM t;
|
||||
-- SQLNESS PROTOCOL POSTGRES
|
||||
SELECT * FROM t;
|
||||
|
||||
SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
SELECT round(vec_cos_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
|
||||
SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_cos_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
SELECT round(vec_cos_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
|
||||
SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_cos_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(cos_distance(v, v), 2) FROM t;
|
||||
SELECT round(vec_cos_distance(v, v), 2) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT cos_distance(v, '[1.0]') FROM t;
|
||||
SELECT vec_cos_distance(v, '[1.0]') FROM t;
|
||||
|
||||
-- Invalid type --
|
||||
SELECT cos_distance(v, 1.0) FROM t;
|
||||
SELECT vec_cos_distance(v, 1.0) FROM t;
|
||||
|
||||
SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
SELECT round(vec_l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
|
||||
SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_l2sq_distance(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
SELECT round(vec_l2sq_distance('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
|
||||
SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_l2sq_distance('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(l2sq_distance(v, v), 2) FROM t;
|
||||
SELECT round(vec_l2sq_distance(v, v), 2) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT l2sq_distance(v, '[1.0]') FROM t;
|
||||
SELECT vec_l2sq_distance(v, '[1.0]') FROM t;
|
||||
|
||||
-- Invalid type --
|
||||
SELECT l2sq_distance(v, 1.0) FROM t;
|
||||
SELECT vec_l2sq_distance(v, 1.0) FROM t;
|
||||
|
||||
|
||||
SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
SELECT round(vec_dot_product(v, '[0.0, 0.0, 0.0]'), 2) FROM t;
|
||||
|
||||
SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_dot_product(v, '[0.0, 0.0, 0.0]'), 2) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
SELECT round(vec_dot_product('[7.0, 8.0, 9.0]', v), 2) FROM t;
|
||||
|
||||
SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
SELECT *, round(vec_dot_product('[7.0, 8.0, 9.0]', v), 2) as d FROM t ORDER BY d;
|
||||
|
||||
SELECT round(dot_product(v, v), 2) FROM t;
|
||||
SELECT round(vec_dot_product(v, v), 2) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
SELECT dot_product(v, '[1.0]') FROM t;
|
||||
SELECT vec_dot_product(v, '[1.0]') FROM t;
|
||||
|
||||
-- Invalid type --
|
||||
SELECT dot_product(v, 1.0) FROM t;
|
||||
SELECT vec_dot_product(v, 1.0) FROM t;
|
||||
|
||||
-- Unexpected dimension --
|
||||
INSERT INTO t VALUES
|
||||
|
||||
Reference in New Issue
Block a user