mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
feat: make DistanceType an independent type instead of reusing the one from lance_linalg (#1133)
This PR originated from a request to add `Serialize` / `Deserialize` to `lance_linalg::distance::DistanceType`. However, that is a strange request for `lance_linalg`, which shouldn't really have to worry about `Serialize` / `Deserialize`. The problem is that `lancedb` is re-using `DistanceType`, and things in `lancedb` do need to worry about `Serialize` / `Deserialize` (because `lancedb` needs to support a remote client). On the bright side, separating the two types allows us to document the distance type independently and allows `lance_linalg` to make changes to `DistanceType` in the future without having to worry about backwards-compatibility concerns.
This commit is contained in:
@@ -12,9 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use lance_linalg::distance::MetricType;
|
||||
use lancedb::index::vector::IvfPqIndexBuilder;
|
||||
use lancedb::index::Index;
|
||||
use lancedb::DistanceType;
|
||||
use neon::context::FunctionContext;
|
||||
use neon::prelude::*;
|
||||
use std::convert::TryFrom;
|
||||
@@ -72,8 +72,8 @@ fn get_index_params_builder(
|
||||
}
|
||||
let mut builder = IvfPqIndexBuilder::default();
|
||||
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
|
||||
let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
|
||||
builder = builder.distance_type(metric_type);
|
||||
let distance_type = DistanceType::try_from(metric_type.value(cx).as_str())?;
|
||||
builder = builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? {
|
||||
builder = builder.num_partitions(np);
|
||||
|
||||
@@ -2,8 +2,8 @@ use std::convert::TryFrom;
|
||||
use std::ops::Deref;
|
||||
|
||||
use futures::{TryFutureExt, TryStreamExt};
|
||||
use lance_linalg::distance::MetricType;
|
||||
use lancedb::query::{ExecutableQuery, QueryBase, Select};
|
||||
use lancedb::DistanceType;
|
||||
use neon::context::FunctionContext;
|
||||
use neon::handle::Handle;
|
||||
use neon::prelude::*;
|
||||
@@ -74,12 +74,12 @@ impl JsQuery {
|
||||
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
|
||||
let mut vector_builder = builder.nearest_to(query).unwrap();
|
||||
if let Some(metric_type) = query_obj
|
||||
if let Some(distance_type) = query_obj
|
||||
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
|
||||
.map(|s| s.value(&mut cx))
|
||||
.map(|s| MetricType::try_from(s.as_str()).unwrap())
|
||||
.map(|s| DistanceType::try_from(s.as_str()).unwrap())
|
||||
{
|
||||
vector_builder = vector_builder.distance_type(metric_type);
|
||||
vector_builder = vector_builder.distance_type(distance_type);
|
||||
}
|
||||
|
||||
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
|
||||
|
||||
@@ -194,9 +194,72 @@ pub(crate) mod remote;
|
||||
pub mod table;
|
||||
pub mod utils;
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use connection::Connection;
|
||||
pub use error::{Error, Result};
|
||||
pub use lance_linalg::distance::DistanceType;
|
||||
use lance_linalg::distance::DistanceType as LanceDistanceType;
|
||||
pub use table::Table;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
pub enum DistanceType {
|
||||
/// Euclidean distance. This is a very common distance metric that
|
||||
/// accounts for both magnitude and direction when determining the distance
|
||||
/// between vectors. L2 distance has a range of [0, ∞).
|
||||
L2,
|
||||
/// Cosine distance. Cosine distance is a distance metric
|
||||
/// calculated from the cosine similarity between two vectors. Cosine
|
||||
/// similarity is a measure of similarity between two non-zero vectors of an
|
||||
/// inner product space. It is defined to equal the cosine of the angle
|
||||
/// between them. Unlike L2, the cosine distance is not affected by the
|
||||
/// magnitude of the vectors. Cosine distance has a range of [0, 2].
|
||||
///
|
||||
/// Note: the cosine distance is undefined when one (or both) of the vectors
|
||||
/// are all zeros (there is no direction). These vectors are invalid and may
|
||||
/// never be returned from a vector search.
|
||||
Cosine,
|
||||
/// Dot product. Dot distance is the dot product of two vectors. Dot
|
||||
/// distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
|
||||
/// L2 norm is 1), then dot distance is equivalent to the cosine distance.
|
||||
Dot,
|
||||
}
|
||||
|
||||
impl From<DistanceType> for LanceDistanceType {
|
||||
fn from(value: DistanceType) -> Self {
|
||||
match value {
|
||||
DistanceType::L2 => LanceDistanceType::L2,
|
||||
DistanceType::Cosine => LanceDistanceType::Cosine,
|
||||
DistanceType::Dot => LanceDistanceType::Dot,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanceDistanceType> for DistanceType {
|
||||
fn from(value: LanceDistanceType) -> Self {
|
||||
match value {
|
||||
LanceDistanceType::L2 => DistanceType::L2,
|
||||
LanceDistanceType::Cosine => DistanceType::Cosine,
|
||||
LanceDistanceType::Dot => DistanceType::Dot,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a str> for DistanceType {
|
||||
type Error = <LanceDistanceType as TryFrom<&'a str>>::Error;
|
||||
|
||||
fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> {
|
||||
LanceDistanceType::try_from(value).map(DistanceType::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for DistanceType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
LanceDistanceType::from(*self).fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a database
|
||||
pub use connection::connect;
|
||||
|
||||
@@ -1108,7 +1108,7 @@ impl NativeTable {
|
||||
/*num_bits=*/ 8,
|
||||
num_sub_vectors as usize,
|
||||
false,
|
||||
index.distance_type,
|
||||
index.distance_type.into(),
|
||||
index.max_iterations as usize,
|
||||
);
|
||||
dataset
|
||||
@@ -1229,7 +1229,7 @@ impl NativeTable {
|
||||
}
|
||||
|
||||
if let Some(distance_type) = query.distance_type {
|
||||
scanner.distance_metric(distance_type);
|
||||
scanner.distance_metric(distance_type.into());
|
||||
}
|
||||
Ok(scanner.try_into_stream().await?)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user