feat: add distance_type() parameter to python sync query builders and metric() as an alias (#2073)

This PR aims to fix #2047 by doing the following things:
- Add a distance_type() method to the sync query builders of the Python
SDK.
- Make metric() an alias for distance_type().
This commit is contained in:
Vaibhav
2025-01-28 13:59:53 -08:00
committed by GitHub
parent 0a9e1eab75
commit dac0857745
6 changed files with 74 additions and 16 deletions

View File

@@ -505,7 +505,7 @@ class LanceQueryBuilder(ABC):
"column": self._vector_column,
"q": self._query,
"k": self._limit,
"metric": self._metric,
"metric": self._distance_type,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
"use_index": self._use_index,
@@ -576,7 +576,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data=data)
>>> (table.search([0.4, 0.4])
... .metric("cosine")
... .distance_type("cosine")
... .where("b < 10")
... .select(["b", "vector"])
... .limit(2)
@@ -596,7 +596,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
):
super().__init__(table)
self._query = query
self._metric = "L2"
self._distance_type = "L2"
self._nprobes = 20
self._lower_bound = None
self._upper_bound = None
@@ -610,6 +610,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
This is an alias for distance_type() and may be deprecated in the future.
Please use distance_type() instead.
Parameters
----------
metric: "L2" or "cosine" or "dot"
@@ -620,7 +623,32 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._metric = metric.lower()
return self.distance_type(metric)
def distance_type(
    self, distance_type: Literal["L2", "cosine", "dot"]
) -> "LanceVectorQueryBuilder":
    """Choose the distance metric used by the vector search.

    A vector search ranks rows by how "near" their vectors are to the query
    vector; the distance metric defines what "near" means.

    Note: if there is a vector index, the distance type given here MUST be
    the same one the index was trained with. Otherwise the results will be
    invalid.

    Parameters
    ----------
    distance_type: "L2" or "cosine" or "dot"
        The distance metric to use. By default "L2" is used.

    Returns
    -------
    LanceVectorQueryBuilder
        This builder, so that calls can be chained.
    """
    # The name is stored lower-cased, so "L2" and "l2" are equivalent inputs.
    normalized = distance_type.lower()
    self._distance_type = normalized
    return self
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
@@ -745,7 +773,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
filter=self._where,
prefilter=self._prefilter,
k=self._limit,
metric=self._metric,
metric=self._distance_type,
columns=self._columns,
nprobes=self._nprobes,
lower_bound=self._lower_bound,
@@ -1078,7 +1106,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._reranker = RRFReranker()
self._nprobes = None
self._refine_factor = None
self._metric = None
self._distance_type = None
self._phrase_query = False
def _validate_query(self, query, vector=None, text=None):
@@ -1146,8 +1174,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._fts_query.with_row_id(True)
if self._phrase_query:
self._fts_query.phrase_query(True)
if self._metric:
self._vector_query.metric(self._metric)
if self._distance_type:
self._vector_query.metric(self._distance_type)
if self._nprobes:
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
@@ -1386,6 +1414,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
"""Set the distance metric to use.
This is an alias for distance_type() and may be deprecated in the future.
Please use distance_type() instead.
Parameters
----------
metric: "L2" or "cosine" or "dot"
@@ -1396,7 +1427,32 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._metric = metric.lower()
return self.distance_type(metric)
def distance_type(
    self, distance_type: Literal["L2", "cosine", "dot"]
) -> "LanceHybridQueryBuilder":
    """Set the distance metric to use for the vector part of the query.

    When performing a vector search we try to find the "nearest" vectors
    according to some kind of distance metric. This parameter controls which
    distance metric to use.

    Note: if there is a vector index then the distance type used MUST match
    the distance type used to train the vector index. If this is not done
    then the results will be invalid.

    Parameters
    ----------
    distance_type: "L2" or "cosine" or "dot"
        The distance metric to use. The value is stored lower-cased.

    Returns
    -------
    LanceHybridQueryBuilder
        This builder, so that calls can be chained.
    """
    # Normalize case so "L2" and "l2" are treated the same.
    self._distance_type = distance_type.lower()
    return self
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:

View File

@@ -38,7 +38,7 @@ def test_binary_vector():
query = np.random.randint(0, 2, size=256)
packed_query = np.packbits(query)
tbl.search(packed_query).metric("hamming").to_arrow()
tbl.search(packed_query).distance_type("hamming").to_arrow()
# --8<-- [end:sync_binary_vector]
db.drop_table("my_binary_vectors")

View File

@@ -65,7 +65,7 @@ def test_vector_search():
tbl.search(np.random.random((1536))).limit(10).to_list()
# --8<-- [end:exhaustive_search]
# --8<-- [start:exhaustive_search_cosine]
tbl.search(np.random.random((1536))).metric("cosine").limit(10).to_list()
tbl.search(np.random.random((1536))).distance_type("cosine").limit(10).to_list()
# --8<-- [end:exhaustive_search_cosine]
# --8<-- [start:create_table_with_nested_schema]
# Let's add 100 sample rows to our dataset

View File

@@ -377,14 +377,14 @@ def test_query_builder_with_metric(table):
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("L2")
.distance_type("L2")
.to_pandas()
)
tm.assert_frame_equal(df_default, df_l2)
df_cosine = (
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.distance_type("cosine")
.limit(1)
.to_pandas()
)
@@ -401,7 +401,7 @@ def test_query_builder_with_different_vector_column():
vector_column_name = "foo_vector"
builder = (
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.distance_type("cosine")
.where("b < 10")
.select(["b"])
.limit(2)

View File

@@ -366,7 +366,7 @@ def test_query_sync_maximal():
with query_test_table(handler) as table:
(
table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
.metric("cosine")
.distance_type("cosine")
.limit(42)
.offset(10)
.refine_factor(10)

View File

@@ -1242,7 +1242,9 @@ def test_hybrid_search_metric_type(tmp_db: DBConnection):
# with custom metric
result_dot = (
table.search("feeling lucky", query_type="hybrid").metric("dot").to_arrow()
table.search("feeling lucky", query_type="hybrid")
.distance_type("dot")
.to_arrow()
)
result_l2 = table.search("feeling lucky", query_type="hybrid").to_arrow()
assert len(result_dot) > 0