From f00b21c98c05b750d0facbf77c3d7f620f575c54 Mon Sep 17 00:00:00 2001 From: QianZhu Date: Tue, 24 Sep 2024 16:10:29 -0700 Subject: [PATCH] fix: metric type for python/node search api (#1689) --- node/src/remote/client.ts | 5 ++++- node/src/remote/index.ts | 1 + python/python/lancedb/query.py | 6 +++--- rust/lancedb/src/lib.rs | 1 + 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/node/src/remote/client.ts b/node/src/remote/client.ts index dbd30893..de70a9e3 100644 --- a/node/src/remote/client.ts +++ b/node/src/remote/client.ts @@ -17,6 +17,7 @@ import axios, { type AxiosResponse, type ResponseType } from 'axios' import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow' import { type RemoteResponse, type RemoteRequest, Method } from '../middleware' +import { MetricType } from '..' interface HttpLancedbClientMiddleware { onRemoteRequest( @@ -152,6 +153,7 @@ export class HttpLancedbClient { refineFactor?: number, columns?: string[], filter?: string, + metricType?: MetricType, fastSearch?: boolean ): Promise> { const result = await this.post( @@ -160,10 +162,11 @@ export class HttpLancedbClient { vector, k, nprobes, - refineFactor, + refine_factor: refineFactor, columns, filter, prefilter, + metric: metricType, fast_search: fastSearch }, undefined, diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index 38bc7a78..df7caab3 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -239,6 +239,7 @@ export class RemoteQuery extends Query { (this as any)._refineFactor, (this as any)._select, (this as any)._filter, + (this as any)._metricType, (this as any)._fastSearch ) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 48b34860..8ef97897 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -576,12 +576,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._reranker = None self._str_query = str_query - def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: + def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder: """Set the distance metric to use. Parameters ---------- - metric: "L2" or "cosine" + metric: "L2" or "cosine" or "dot" The distance metric to use. By default "L2" is used. Returns @@ -589,7 +589,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): LanceVectorQueryBuilder The LanceQueryBuilder object. """ - self._metric = metric + self._metric = metric.lower() return self def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder: diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 0b7775bd..c552c30b 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -228,6 +228,7 @@ pub use table::Table; #[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)] #[non_exhaustive] +#[serde(rename_all = "lowercase")] pub enum DistanceType { /// Euclidean distance. This is a very common distance metric that /// accounts for both magnitude and direction when determining the distance