mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 03:12:57 +00:00
feat: add distance_type() parameter to python sync query builders and metric() as an alias (#2073)
This PR aims to fix #2047 by doing the following things: - Add a distance_type parameter to the sync query builders of Python SDK. - Make metric an alias to distance_type.
This commit is contained in:
@@ -505,7 +505,7 @@ class LanceQueryBuilder(ABC):
|
||||
"column": self._vector_column,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._metric,
|
||||
"metric": self._distance_type,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
"use_index": self._use_index,
|
||||
@@ -576,7 +576,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data=data)
|
||||
>>> (table.search([0.4, 0.4])
|
||||
... .metric("cosine")
|
||||
... .distance_type("cosine")
|
||||
... .where("b < 10")
|
||||
... .select(["b", "vector"])
|
||||
... .limit(2)
|
||||
@@ -596,7 +596,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._metric = "L2"
|
||||
self._distance_type = "L2"
|
||||
self._nprobes = 20
|
||||
self._lower_bound = None
|
||||
self._upper_bound = None
|
||||
@@ -610,6 +610,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
This is an alias for distance_type() and may be deprecated in the future.
|
||||
Please use distance_type() instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: "L2" or "cosine" or "dot"
|
||||
@@ -620,7 +623,32 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._metric = metric.lower()
|
||||
return self.distance_type(metric)
|
||||
|
||||
def distance_type(
|
||||
self, distance_type: Literal["L2", "cosine", "dot"]
|
||||
) -> "LanceVectorQueryBuilder":
|
||||
"""Set the distance metric to use.
|
||||
|
||||
When performing a vector search we try and find the "nearest" vectors according
|
||||
to some kind of distance metric. This parameter controls which distance metric
|
||||
to use.
|
||||
|
||||
Note: if there is a vector index then the distance type used MUST match the
|
||||
distance type used to train the vector index. If this is not done then the
|
||||
results will be invalid.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_type: "L2" or "cosine" or "dot"
|
||||
The distance metric to use. By default "L2" is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._distance_type = distance_type.lower()
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
|
||||
@@ -745,7 +773,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
filter=self._where,
|
||||
prefilter=self._prefilter,
|
||||
k=self._limit,
|
||||
metric=self._metric,
|
||||
metric=self._distance_type,
|
||||
columns=self._columns,
|
||||
nprobes=self._nprobes,
|
||||
lower_bound=self._lower_bound,
|
||||
@@ -1078,7 +1106,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._reranker = RRFReranker()
|
||||
self._nprobes = None
|
||||
self._refine_factor = None
|
||||
self._metric = None
|
||||
self._distance_type = None
|
||||
self._phrase_query = False
|
||||
|
||||
def _validate_query(self, query, vector=None, text=None):
|
||||
@@ -1146,8 +1174,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._fts_query.with_row_id(True)
|
||||
if self._phrase_query:
|
||||
self._fts_query.phrase_query(True)
|
||||
if self._metric:
|
||||
self._vector_query.metric(self._metric)
|
||||
if self._distance_type:
|
||||
self._vector_query.metric(self._distance_type)
|
||||
if self._nprobes:
|
||||
self._vector_query.nprobes(self._nprobes)
|
||||
if self._refine_factor:
|
||||
@@ -1386,6 +1414,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
This is an alias for distance_type() and may be deprecated in the future.
|
||||
Please use distance_type() instead.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: "L2" or "cosine" or "dot"
|
||||
@@ -1396,7 +1427,32 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._metric = metric.lower()
|
||||
return self.distance_type(metric)
|
||||
|
||||
def distance_type(
|
||||
self, distance_type: Literal["L2", "cosine", "dot"]
|
||||
) -> "LanceHybridQueryBuilder":
|
||||
"""Set the distance metric to use.
|
||||
|
||||
When performing a vector search we try and find the "nearest" vectors according
|
||||
to some kind of distance metric. This parameter controls which distance metric
|
||||
to use.
|
||||
|
||||
Note: if there is a vector index then the distance type used MUST match the
|
||||
distance type used to train the vector index. If this is not done then the
|
||||
results will be invalid.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_type: "L2" or "cosine" or "dot"
|
||||
The distance metric to use. By default "L2" is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._distance_type = distance_type.lower()
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
|
||||
|
||||
@@ -38,7 +38,7 @@ def test_binary_vector():
|
||||
|
||||
query = np.random.randint(0, 2, size=256)
|
||||
packed_query = np.packbits(query)
|
||||
tbl.search(packed_query).metric("hamming").to_arrow()
|
||||
tbl.search(packed_query).distance_type("hamming").to_arrow()
|
||||
# --8<-- [end:sync_binary_vector]
|
||||
db.drop_table("my_binary_vectors")
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ def test_vector_search():
|
||||
tbl.search(np.random.random((1536))).limit(10).to_list()
|
||||
# --8<-- [end:exhaustive_search]
|
||||
# --8<-- [start:exhaustive_search_cosine]
|
||||
tbl.search(np.random.random((1536))).metric("cosine").limit(10).to_list()
|
||||
tbl.search(np.random.random((1536))).distance_type("cosine").limit(10).to_list()
|
||||
# --8<-- [end:exhaustive_search_cosine]
|
||||
# --8<-- [start:create_table_with_nested_schema]
|
||||
# Let's add 100 sample rows to our dataset
|
||||
|
||||
@@ -377,14 +377,14 @@ def test_query_builder_with_metric(table):
|
||||
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
|
||||
df_l2 = (
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("L2")
|
||||
.distance_type("L2")
|
||||
.to_pandas()
|
||||
)
|
||||
tm.assert_frame_equal(df_default, df_l2)
|
||||
|
||||
df_cosine = (
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.distance_type("cosine")
|
||||
.limit(1)
|
||||
.to_pandas()
|
||||
)
|
||||
@@ -401,7 +401,7 @@ def test_query_builder_with_different_vector_column():
|
||||
vector_column_name = "foo_vector"
|
||||
builder = (
|
||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.distance_type("cosine")
|
||||
.where("b < 10")
|
||||
.select(["b"])
|
||||
.limit(2)
|
||||
|
||||
@@ -366,7 +366,7 @@ def test_query_sync_maximal():
|
||||
with query_test_table(handler) as table:
|
||||
(
|
||||
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
|
||||
.metric("cosine")
|
||||
.distance_type("cosine")
|
||||
.limit(42)
|
||||
.offset(10)
|
||||
.refine_factor(10)
|
||||
|
||||
@@ -1242,7 +1242,9 @@ def test_hybrid_search_metric_type(tmp_db: DBConnection):
|
||||
|
||||
# with custom metric
|
||||
result_dot = (
|
||||
table.search("feeling lucky", query_type="hybrid").metric("dot").to_arrow()
|
||||
table.search("feeling lucky", query_type="hybrid")
|
||||
.distance_type("dot")
|
||||
.to_arrow()
|
||||
)
|
||||
result_l2 = table.search("feeling lucky", query_type="hybrid").to_arrow()
|
||||
assert len(result_dot) > 0
|
||||
|
||||
Reference in New Issue
Block a user