feat: change DistanceType to be an independent type instead of reusing lance_linalg (#1133)

This PR originated from a request to add `Serialize` / `Deserialize` to
`lance_linalg::distance::DistanceType`. However, that is a strange
request for `lance_linalg` which shouldn't really have to worry about
`Serialize` / `Deserialize`. The problem is that `lancedb` is re-using
`DistanceType` and things in `lancedb` do need to worry about
`Serialize`/`Deserialize` (because `lancedb` needs to support a remote
client).

On the bright side, separating the two types allows us to independently
document distance type and allows `lance_linalg` to make changes to
`DistanceType` in the future without having to worry about backwards
compatibility concerns.
This commit is contained in:
Weston Pace
2024-03-19 07:27:51 -07:00
parent 0fe0976a0e
commit a23b856410
4 changed files with 73 additions and 10 deletions

View File

@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use lance_linalg::distance::MetricType;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index;
use lancedb::DistanceType;
use neon::context::FunctionContext;
use neon::prelude::*;
use std::convert::TryFrom;
@@ -72,8 +72,8 @@ fn get_index_params_builder(
}
let mut builder = IvfPqIndexBuilder::default();
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
builder = builder.distance_type(metric_type);
let distance_type = DistanceType::try_from(metric_type.value(cx).as_str())?;
builder = builder.distance_type(distance_type);
}
if let Some(np) = obj.get_opt_u32(cx, "num_partitions")? {
builder = builder.num_partitions(np);

View File

@@ -2,8 +2,8 @@ use std::convert::TryFrom;
use std::ops::Deref;
use futures::{TryFutureExt, TryStreamExt};
use lance_linalg::distance::MetricType;
use lancedb::query::{ExecutableQuery, QueryBase, Select};
use lancedb::DistanceType;
use neon::context::FunctionContext;
use neon::handle::Handle;
use neon::prelude::*;
@@ -74,12 +74,12 @@ impl JsQuery {
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
let mut vector_builder = builder.nearest_to(query).unwrap();
if let Some(metric_type) = query_obj
if let Some(distance_type) = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap())
.map(|s| DistanceType::try_from(s.as_str()).unwrap())
{
vector_builder = vector_builder.distance_type(metric_type);
vector_builder = vector_builder.distance_type(distance_type);
}
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;

View File

@@ -194,9 +194,72 @@ pub(crate) mod remote;
pub mod table;
pub mod utils;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
pub use connection::Connection;
pub use error::{Error, Result};
pub use lance_linalg::distance::DistanceType;
use lance_linalg::distance::DistanceType as LanceDistanceType;
pub use table::Table;
/// The distance metric used to compare two vectors.
///
/// This is `lancedb`'s own distance type, kept separate from
/// `lance_linalg::distance::DistanceType` so that it can derive
/// `Serialize` / `Deserialize` and be documented independently. It converts
/// to and from the `lance_linalg` type via `From`, can be parsed from a
/// string via `TryFrom<&str>`, and implements `Display`.
///
/// The enum is `#[non_exhaustive]`: code outside this crate that matches on
/// it must include a wildcard arm, which leaves room for adding new
/// variants later.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DistanceType {
/// Euclidean distance. This is a very common distance metric that
/// accounts for both magnitude and direction when determining the distance
/// between vectors. L2 distance has a range of [0, ∞).
L2,
/// Cosine distance. Cosine distance is a distance metric
/// calculated from the cosine similarity between two vectors. Cosine
/// similarity is a measure of similarity between two non-zero vectors of an
/// inner product space. It is defined to equal the cosine of the angle
/// between them. Unlike L2, the cosine distance is not affected by the
/// magnitude of the vectors. Cosine distance has a range of [0, 2].
///
/// Note: the cosine distance is undefined when one (or both) of the vectors
/// are all zeros (there is no direction). These vectors are invalid and may
/// never be returned from a vector search.
Cosine,
/// Dot product. Dot distance is the dot product of two vectors. Dot
/// distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
/// L2 norm is 1), then dot distance is equivalent to the cosine distance.
Dot,
}
/// Maps each `lancedb` distance type onto the `lance_linalg` variant of the
/// same name so it can be handed to the underlying lance APIs.
impl From<DistanceType> for LanceDistanceType {
    fn from(distance_type: DistanceType) -> Self {
        // No wildcard arm on purpose: adding a variant to `DistanceType`
        // must force this mapping to be updated.
        match distance_type {
            DistanceType::L2 => Self::L2,
            DistanceType::Cosine => Self::Cosine,
            DistanceType::Dot => Self::Dot,
        }
    }
}
/// Maps each `lance_linalg` distance type onto the `lancedb` variant of the
/// same name.
impl From<LanceDistanceType> for DistanceType {
    fn from(lance_type: LanceDistanceType) -> Self {
        // Exhaustive on purpose: a new upstream variant should surface here
        // as a compile error rather than be silently mis-mapped.
        match lance_type {
            LanceDistanceType::L2 => Self::L2,
            LanceDistanceType::Cosine => Self::Cosine,
            LanceDistanceType::Dot => Self::Dot,
        }
    }
}
/// Parses a distance type from its string form.
///
/// Parsing is delegated entirely to `lance_linalg`'s `TryFrom<&str>`
/// implementation, so the accepted spellings and the error type are
/// whatever that implementation defines.
impl<'a> TryFrom<&'a str> for DistanceType {
    type Error = <LanceDistanceType as TryFrom<&'a str>>::Error;

    // `std::result::Result` is spelled out because this module re-exports
    // the crate's own `Result` alias.
    fn try_from(value: &'a str) -> std::result::Result<Self, Self::Error> {
        let parsed = LanceDistanceType::try_from(value)?;
        Ok(Self::from(parsed))
    }
}
impl Display for DistanceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
LanceDistanceType::from(*self).fmt(f)
}
}
/// Connect to a database
pub use connection::connect;

View File

@@ -1108,7 +1108,7 @@ impl NativeTable {
/*num_bits=*/ 8,
num_sub_vectors as usize,
false,
index.distance_type,
index.distance_type.into(),
index.max_iterations as usize,
);
dataset
@@ -1229,7 +1229,7 @@ impl NativeTable {
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type);
scanner.distance_metric(distance_type.into());
}
Ok(scanner.try_into_stream().await?)
}