From a23b856410c4154e28823c4ec2db3a0b468728cb Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 19 Mar 2024 07:27:51 -0700 Subject: [PATCH] feat: change DistanceType to be independent thing instead of reusing lance_linalg (#1133) This PR originated from a request to add `Serialize` / `Deserialize` to `lance_linalg::distance::DistanceType`. However, that is a strange request for `lance_linalg` which shouldn't really have to worry about `Serialize` / `Deserialize`. The problem is that `lancedb` is re-using `DistanceType` and things in `lancedb` do need to worry about `Serialize`/`Deserialize` (because `lancedb` needs to support a remote client). On the bright side, separating the two types allows us to independently document distance type and allows `lance_linalg` to make changes to `DistanceType` in the future without having to worry about backwards compatibility concerns. --- rust/ffi/node/src/index/vector.rs | 6 +-- rust/ffi/node/src/query.rs | 8 ++-- rust/lancedb/src/lib.rs | 65 ++++++++++++++++++++++++++++++- rust/lancedb/src/table.rs | 4 +- 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/rust/ffi/node/src/index/vector.rs b/rust/ffi/node/src/index/vector.rs index 3190ee2d..bcc74cb5 100644 --- a/rust/ffi/node/src/index/vector.rs +++ b/rust/ffi/node/src/index/vector.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use lance_linalg::distance::MetricType; use lancedb::index::vector::IvfPqIndexBuilder; use lancedb::index::Index; +use lancedb::DistanceType; use neon::context::FunctionContext; use neon::prelude::*; use std::convert::TryFrom; @@ -72,8 +72,8 @@ fn get_index_params_builder( } let mut builder = IvfPqIndexBuilder::default(); if let Some(metric_type) = obj.get_opt::(cx, "metric_type")? 
{ - let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?; - builder = builder.distance_type(metric_type); + let distance_type = DistanceType::try_from(metric_type.value(cx).as_str())?; + builder = builder.distance_type(distance_type); } if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? { builder = builder.num_partitions(np); diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index 7e00ac21..7ec5a375 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -2,8 +2,8 @@ use std::convert::TryFrom; use std::ops::Deref; use futures::{TryFutureExt, TryStreamExt}; -use lance_linalg::distance::MetricType; use lancedb::query::{ExecutableQuery, QueryBase, Select}; +use lancedb::DistanceType; use neon::context::FunctionContext; use neon::handle::Handle; use neon::prelude::*; @@ -74,12 +74,12 @@ impl JsQuery { let query_vector = query_obj.get_opt::(&mut cx, "_queryVector")?; if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) { let mut vector_builder = builder.nearest_to(query).unwrap(); - if let Some(metric_type) = query_obj + if let Some(distance_type) = query_obj .get_opt::(&mut cx, "_metricType")? 
.map(|s| s.value(&mut cx)) - .map(|s| MetricType::try_from(s.as_str()).unwrap()) + .map(|s| DistanceType::try_from(s.as_str()).unwrap()) { - vector_builder = vector_builder.distance_type(metric_type); + vector_builder = vector_builder.distance_type(distance_type); } let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?; diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 99c7e888..2485e965 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -194,9 +194,72 @@ pub(crate) mod remote; pub mod table; pub mod utils; +use std::fmt::Display; + +use serde::{Deserialize, Serialize}; + +pub use connection::Connection; pub use error::{Error, Result}; -pub use lance_linalg::distance::DistanceType; +use lance_linalg::distance::DistanceType as LanceDistanceType; pub use table::Table; +#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)] +#[non_exhaustive] +pub enum DistanceType { + /// Euclidean distance. This is a very common distance metric that + /// accounts for both magnitude and direction when determining the distance + /// between vectors. L2 distance has a range of [0, ∞). + L2, + /// Cosine distance. Cosine distance is a distance metric + /// calculated from the cosine similarity between two vectors. Cosine + /// similarity is a measure of similarity between two non-zero vectors of an + /// inner product space. It is defined to equal the cosine of the angle + /// between them. Unlike L2, the cosine distance is not affected by the + /// magnitude of the vectors. Cosine distance has a range of [0, 2]. + /// + /// Note: the cosine distance is undefined when one (or both) of the vectors + /// are all zeros (there is no direction). These vectors are invalid and may + /// never be returned from a vector search. + Cosine, + /// Dot product. Dot distance is the dot product of two vectors. Dot + /// distance has a range of (-∞, ∞). If the vectors are normalized (i.e. 
their + /// L2 norm is 1), then dot distance is equivalent to the cosine distance. + Dot, +} + +impl From<DistanceType> for LanceDistanceType { + fn from(value: DistanceType) -> Self { + match value { + DistanceType::L2 => LanceDistanceType::L2, + DistanceType::Cosine => LanceDistanceType::Cosine, + DistanceType::Dot => LanceDistanceType::Dot, + } + } +} + +impl From<LanceDistanceType> for DistanceType { + fn from(value: LanceDistanceType) -> Self { + match value { + LanceDistanceType::L2 => DistanceType::L2, + LanceDistanceType::Cosine => DistanceType::Cosine, + LanceDistanceType::Dot => DistanceType::Dot, + } + } +} + +impl<'a> TryFrom<&'a str> for DistanceType { + type Error = <LanceDistanceType as TryFrom<&'a str>>::Error; + + fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> { + LanceDistanceType::try_from(value).map(DistanceType::from) + } +} + +impl Display for DistanceType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + LanceDistanceType::from(*self).fmt(f) + } +} + /// Connect to a database pub use connection::connect; diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index cd09f5ca..20208925 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1108,7 +1108,7 @@ impl NativeTable { /*num_bits=*/ 8, num_sub_vectors as usize, false, - index.distance_type, + index.distance_type.into(), index.max_iterations as usize, ); dataset @@ -1229,7 +1229,7 @@ impl NativeTable { } if let Some(distance_type) = query.distance_type { - scanner.distance_metric(distance_type); + scanner.distance_metric(distance_type.into()); } Ok(scanner.try_into_stream().await?) }