feat: add distance functions (#4987)

* feat: add distance functions

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

* fix: f64 instead

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

* address comments

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

* tiny adjust

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

---------

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
This commit is contained in:
Zhenchi
2024-11-14 18:18:58 +08:00
committed by GitHub
parent 22c8a7656b
commit 408013c22b
9 changed files with 755 additions and 10 deletions

24
Cargo.lock generated
View File

@@ -1041,7 +1041,7 @@ dependencies = [
"bitflags 2.6.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
@@ -2088,6 +2088,7 @@ dependencies = [
"serde",
"serde_json",
"session",
"simsimd",
"snafu 0.8.5",
"sql",
"statrs",
@@ -5090,7 +5091,7 @@ dependencies = [
"httpdate",
"itoa",
"pin-project-lite",
"socket2 0.4.10",
"socket2 0.5.7",
"tokio",
"tower-service",
"tracing",
@@ -6080,7 +6081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4"
dependencies = [
"cfg-if",
"windows-targets 0.48.5",
"windows-targets 0.52.6",
]
[[package]]
@@ -8822,7 +8823,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"itertools 0.10.5",
"itertools 0.12.1",
"log",
"multimap",
"once_cell",
@@ -8874,7 +8875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
dependencies = [
"anyhow",
"itertools 0.10.5",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.79",
@@ -9036,7 +9037,7 @@ dependencies = [
"indoc",
"libc",
"memoffset 0.9.1",
"parking_lot 0.11.2",
"parking_lot 0.12.3",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
@@ -11198,6 +11199,15 @@ dependencies = [
"time",
]
[[package]]
name = "simsimd"
version = "4.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
dependencies = [
"cc",
]
[[package]]
name = "siphasher"
version = "0.3.11"
@@ -13981,7 +13991,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.59.0",
]
[[package]]

View File

@@ -41,6 +41,7 @@ s2 = { version = "0.0.12", optional = true }
serde.workspace = true
serde_json.workspace = true
session.workspace = true
simsimd = "4"
snafu.workspace = true
sql.workspace = true
statrs = "0.16"

View File

@@ -27,6 +27,7 @@ use crate::scalars::matches::MatchesFunction;
use crate::scalars::math::MathFunction;
use crate::scalars::numpy::NumpyFunction;
use crate::scalars::timestamp::TimestampFunction;
use crate::scalars::vector::VectorFunction;
use crate::system::SystemFunction;
use crate::table::TableFunction;
@@ -120,6 +121,9 @@ pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
// Json related functions
JsonFunction::register(&function_registry);
// Vector related functions
VectorFunction::register(&function_registry);
// Geo functions
#[cfg(feature = "geo")]
crate::scalars::geo::GeoFunctions::register(&function_registry);

View File

@@ -21,6 +21,7 @@ pub mod json;
pub mod matches;
pub mod math;
pub mod numpy;
pub mod vector;
#[cfg(test)]
pub(crate) mod test;

View File

@@ -0,0 +1,31 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod distance;
use std::sync::Arc;
use distance::{CosDistanceFunction, DotProductFunction, L2SqDistanceFunction};
use crate::function_registry::FunctionRegistry;
pub(crate) struct VectorFunction;
impl VectorFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register(Arc::new(CosDistanceFunction));
registry.register(Arc::new(DotProductFunction));
registry.register(Arc::new(L2SqDistanceFunction));
}
}

View File

@@ -0,0 +1,469 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::prelude::Signature;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::value::ValueRef;
use datatypes::vectors::{Float64VectorBuilder, MutableVector, Vector, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::helper;
macro_rules! define_distance_function {
($StructName:ident, $display_name:expr, $similarity_method:ident) => {
/// A function calculates the distance between two vectors.
#[derive(Debug, Clone, Default)]
pub struct $StructName;
impl Function for $StructName {
fn name(&self) -> &str {
$display_name
}
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::float64_datatype())
}
fn signature(&self) -> Signature {
helper::one_of_sigs2(
vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::binary_datatype(),
],
vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::binary_datatype(),
],
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
let size = arg0.len();
let mut result = Float64VectorBuilder::with_capacity(size);
if size == 0 {
return Ok(result.to_vector());
}
let arg0_const = parse_if_constant_string(arg0)?;
let arg1_const = parse_if_constant_string(arg1)?;
for i in 0..size {
let vec0 = match arg0_const.as_ref() {
Some(a) => Some(Cow::Borrowed(a.as_slice())),
None => as_vector(arg0.get_ref(i))?,
};
let vec1 = match arg1_const.as_ref() {
Some(b) => Some(Cow::Borrowed(b.as_slice())),
None => as_vector(arg1.get_ref(i))?,
};
if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
ensure!(
vec0.len() == vec1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match to calculate distance, have: {} vs {}",
vec0.len(),
vec1.len()
),
}
);
let f = <f32 as simsimd::SpatialSimilarity>::$similarity_method;
// Safe: checked if the length of the vectors match
let d = f(vec0.as_ref(), vec1.as_ref()).unwrap();
result.push(Some(d));
} else {
result.push_null();
}
}
return Ok(result.to_vector());
}
}
impl Display for $StructName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", $display_name.to_ascii_uppercase())
}
}
}
}
define_distance_function!(CosDistanceFunction, "cos_distance", cos);
define_distance_function!(L2SqDistanceFunction, "l2sq_distance", l2sq);
define_distance_function!(DotProductFunction, "dot_product", dot);
/// Parse a vector value if the value is a constant string.
fn parse_if_constant_string(arg: &Arc<dyn Vector>) -> Result<Option<Vec<f32>>> {
if !arg.is_const() {
return Ok(None);
}
if arg.data_type() != ConcreteDataType::string_datatype() {
return Ok(None);
}
arg.get_ref(0)
.as_string()
.unwrap() // Safe: checked if it is a string
.map(parse_f32_vector_from_string)
.transpose()
}
/// Convert a value to a vector value.
/// Supported data types are binary and string.
fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
match arg.data_type() {
ConcreteDataType::Binary(_) => arg
.as_binary()
.unwrap() // Safe: checked if it is a binary
.map(|bytes| Ok(Cow::Borrowed(binary_as_vector(bytes)?)))
.transpose(),
ConcreteDataType::String(_) => arg
.as_string()
.unwrap() // Safe: checked if it is a string
.map(|s| Ok(Cow::Owned(parse_f32_vector_from_string(s)?)))
.transpose(),
ConcreteDataType::Null(_) => Ok(None),
_ => InvalidFuncArgsSnafu {
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
}
.fail(),
}
}
/// Convert a u8 slice to a vector value.
fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> {
if bytes.len() % 4 != 0 {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
}
.fail();
}
unsafe {
let num_floats = bytes.len() / 4;
let floats: &[f32] = std::slice::from_raw_parts(bytes.as_ptr() as *const f32, num_floats);
Ok(floats)
}
}
/// Parse a string to a vector value.
/// Valid inputs are strings like "[1.0, 2.0, 3.0]".
fn parse_f32_vector_from_string(s: &str) -> Result<Vec<f32>> {
let trimmed = s.trim();
if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
return InvalidFuncArgsSnafu {
err_msg: format!(
"Failed to parse {s} to Vector value: not properly enclosed in brackets"
),
}
.fail();
}
let content = trimmed[1..trimmed.len() - 1].trim();
if content.is_empty() {
return Ok(Vec::new());
}
content
.split(',')
.map(|s| s.trim().parse::<f32>())
.collect::<std::result::Result<_, _>>()
.map_err(|e| {
InvalidFuncArgsSnafu {
err_msg: format!("Failed to parse {s} to Vector value: {e}"),
}
.build()
})
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};
use super::*;
#[test]
fn test_distance_string_string() {
let funcs = [
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
Box::new(DotProductFunction {}) as Box<dyn Function>,
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(StringVector::from(vec![
Some("[0.0, 1.0]"),
Some("[0.0, 1.0]"),
Some("[0.0, 1.0]"),
None,
])) as VectorRef;
let result = func
.eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
let result = func
.eval(FunctionContext::default(), &[vec2, vec1])
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
}
}
#[test]
fn test_distance_binary_binary() {
let funcs = [
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
Box::new(DotProductFunction {}) as Box<dyn Function>,
];
for func in funcs {
let vec1 = Arc::new(BinaryVector::from(vec![
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
None,
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
let result = func
.eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
let result = func
.eval(FunctionContext::default(), &[vec2, vec1])
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
}
}
#[test]
fn test_distance_string_binary() {
let funcs = [
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
Box::new(DotProductFunction {}) as Box<dyn Function>,
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
let result = func
.eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
let result = func
.eval(FunctionContext::default(), &[vec2, vec1])
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
}
}
#[test]
fn test_distance_const_string() {
let funcs = [
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
Box::new(DotProductFunction {}) as Box<dyn Function>,
];
for func in funcs {
let const_str = Arc::new(ConstantVector::new(
Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
4,
));
let vec1 = Arc::new(StringVector::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
let result = func
.eval(
FunctionContext::default(),
&[const_str.clone(), vec1.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(!result.get(3).is_null());
let result = func
.eval(
FunctionContext::default(),
&[vec1.clone(), const_str.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(!result.get(3).is_null());
let result = func
.eval(
FunctionContext::default(),
&[const_str.clone(), vec2.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(!result.get(2).is_null());
assert!(result.get(3).is_null());
let result = func
.eval(
FunctionContext::default(),
&[vec2.clone(), const_str.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(!result.get(2).is_null());
assert!(result.get(3).is_null());
}
}
#[test]
fn test_invalid_vector_length() {
let funcs = [
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
Box::new(DotProductFunction {}) as Box<dyn Function>,
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
let result = func.eval(FunctionContext::default(), &[vec1, vec2]);
assert!(result.is_err());
let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
let vec2 =
Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
let result = func.eval(FunctionContext::default(), &[vec1, vec2]);
assert!(result.is_err());
}
}
#[test]
fn test_parse_vector_from_string() {
let result = parse_f32_vector_from_string("[1.0, 2.0, 3.0]").unwrap();
assert_eq!(result, vec![1.0, 2.0, 3.0]);
let result = parse_f32_vector_from_string("[]").unwrap();
assert_eq!(result, Vec::<f32>::new());
let result = parse_f32_vector_from_string("[1.0, a, 3.0]");
assert!(result.is_err());
}
#[test]
fn test_binary_as_vector() {
let bytes = [0, 0, 128, 63];
let result = binary_as_vector(&bytes).unwrap();
assert_eq!(result, &[1.0]);
let invalid_bytes = [0, 0, 128];
let result = binary_as_vector(&invalid_bytes);
assert!(result.is_err());
}
}

View File

@@ -1089,7 +1089,7 @@ macro_rules! impl_as_for_value_ref {
};
}
impl ValueRef<'_> {
impl<'a> ValueRef<'a> {
define_data_type_func!(ValueRef);
/// Returns true if this is null.
@@ -1098,12 +1098,12 @@ impl ValueRef<'_> {
}
/// Cast itself to binary slice.
pub fn as_binary(&self) -> Result<Option<&[u8]>> {
pub fn as_binary(&self) -> Result<Option<&'a [u8]>> {
impl_as_for_value_ref!(self, Binary)
}
/// Cast itself to string slice.
pub fn as_string(&self) -> Result<Option<&str>> {
pub fn as_string(&self) -> Result<Option<&'a str>> {
impl_as_for_value_ref!(self, String)
}

View File

@@ -31,6 +31,186 @@ SELECT * FROM t;
| 1970-01-01 00:00:00.003000 | "[7,8,9]" |
+----------------------------+-----------+
SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
+-----------------------------------------------------------+
| round(cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) |
+-----------------------------------------------------------+
| 1.0 |
| 1.0 |
| 1.0 |
+-----------------------------------------------------------+
SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
+-------------------------+--------------------------+-----+
| ts | v | d |
+-------------------------+--------------------------+-----+
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 1.0 |
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 1.0 |
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 1.0 |
+-------------------------+--------------------------+-----+
SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
+-----------------------------------------------------------+
| round(cos_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) |
+-----------------------------------------------------------+
| 0.0406 |
| 0.0018 |
| 0.0 |
+-----------------------------------------------------------+
SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
+-------------------------+--------------------------+--------+
| ts | v | d |
+-------------------------+--------------------------+--------+
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 |
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0018 |
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0406 |
+-------------------------+--------------------------+--------+
SELECT round(cos_distance(v, v), 4) FROM t;
+---------------------------------------+
| round(cos_distance(t.v,t.v),Int64(4)) |
+---------------------------------------+
| 0.0 |
| 0.0 |
| 0.0 |
+---------------------------------------+
-- Unexpected dimension --
SELECT cos_distance(v, '[1.0]') FROM t;
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
-- Invalid type --
SELECT cos_distance(v, 1.0) FROM t;
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
+------------------------------------------------------------+
| round(l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) |
+------------------------------------------------------------+
| 14.0 |
| 77.0 |
| 194.0 |
+------------------------------------------------------------+
SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
+-------------------------+--------------------------+-------+
| ts | v | d |
+-------------------------+--------------------------+-------+
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 14.0 |
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 77.0 |
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 |
+-------------------------+--------------------------+-------+
SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
+------------------------------------------------------------+
| round(l2sq_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) |
+------------------------------------------------------------+
| 108.0 |
| 27.0 |
| 0.0 |
+------------------------------------------------------------+
SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
+-------------------------+--------------------------+-------+
| ts | v | d |
+-------------------------+--------------------------+-------+
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 |
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 27.0 |
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 108.0 |
+-------------------------+--------------------------+-------+
SELECT round(l2sq_distance(v, v), 4) FROM t;
+----------------------------------------+
| round(l2sq_distance(t.v,t.v),Int64(4)) |
+----------------------------------------+
| 0.0 |
| 0.0 |
| 0.0 |
+----------------------------------------+
-- Unexpected dimension --
SELECT l2sq_distance(v, '[1.0]') FROM t;
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
-- Invalid type --
SELECT l2sq_distance(v, 1.0) FROM t;
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
+----------------------------------------------------------+
| round(dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) |
+----------------------------------------------------------+
| 0.0 |
| 0.0 |
| 0.0 |
+----------------------------------------------------------+
SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
+-------------------------+--------------------------+-----+
| ts | v | d |
+-------------------------+--------------------------+-----+
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0 |
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0 |
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 |
+-------------------------+--------------------------+-----+
SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t;
+----------------------------------------------------------+
| round(dot_product(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) |
+----------------------------------------------------------+
| 50.0 |
| 122.0 |
| 194.0 |
+----------------------------------------------------------+
SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
+-------------------------+--------------------------+-------+
| ts | v | d |
+-------------------------+--------------------------+-------+
| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 50.0 |
| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 122.0 |
| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 |
+-------------------------+--------------------------+-------+
SELECT round(dot_product(v, v), 4) FROM t;
+--------------------------------------+
| round(dot_product(t.v,t.v),Int64(4)) |
+--------------------------------------+
| 14.0 |
| 77.0 |
| 194.0 |
+--------------------------------------+
-- Unexpected dimension --
SELECT dot_product(v, '[1.0]') FROM t;
Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1
-- Invalid type --
SELECT dot_product(v, 1.0) FROM t;
Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2
-- Unexpected dimension --
INSERT INTO t VALUES
(4, "[1.0]");

View File

@@ -11,6 +11,55 @@ SELECT * FROM t;
-- SQLNESS PROTOCOL POSTGRES
SELECT * FROM t;
SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
SELECT round(cos_distance(v, v), 4) FROM t;
-- Unexpected dimension --
SELECT cos_distance(v, '[1.0]') FROM t;
-- Invalid type --
SELECT cos_distance(v, 1.0) FROM t;
SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t;
SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
SELECT round(l2sq_distance(v, v), 4) FROM t;
-- Unexpected dimension --
SELECT l2sq_distance(v, '[1.0]') FROM t;
-- Invalid type --
SELECT l2sq_distance(v, 1.0) FROM t;
SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t;
SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d;
SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t;
SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d;
SELECT round(dot_product(v, v), 4) FROM t;
-- Unexpected dimension --
SELECT dot_product(v, '[1.0]') FROM t;
-- Invalid type --
SELECT dot_product(v, 1.0) FROM t;
-- Unexpected dimension --
INSERT INTO t VALUES
(4, "[1.0]");