diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 5f18a85e..087a8a6b 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -505,7 +505,7 @@ class LanceQueryBuilder(ABC): "column": self._vector_column, "q": self._query, "k": self._limit, - "metric": self._metric, + "metric": self._distance_type, "nprobes": self._nprobes, "refine_factor": self._refine_factor, "use_index": self._use_index, @@ -576,7 +576,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): >>> db = lancedb.connect("./.lancedb") >>> table = db.create_table("my_table", data=data) >>> (table.search([0.4, 0.4]) - ... .metric("cosine") + ... .distance_type("cosine") ... .where("b < 10") ... .select(["b", "vector"]) ... .limit(2) @@ -596,7 +596,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): ): super().__init__(table) self._query = query - self._metric = "L2" + self._distance_type = "L2" self._nprobes = 20 self._lower_bound = None self._upper_bound = None @@ -610,6 +610,9 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder: """Set the distance metric to use. + This is an alias for distance_type() and may be deprecated in the future. + Please use distance_type() instead. + Parameters ---------- metric: "L2" or "cosine" or "dot" @@ -620,7 +623,32 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): LanceVectorQueryBuilder The LanceQueryBuilder object. """ - self._metric = metric.lower() + return self.distance_type(metric) + + def distance_type( + self, distance_type: Literal["L2", "cosine", "dot"] + ) -> "LanceVectorQueryBuilder": + """Set the distance metric to use. + + When performing a vector search we try and find the "nearest" vectors according + to some kind of distance metric. This parameter controls which distance metric + to use. + + Note: if there is a vector index then the distance type used MUST match the + distance type used to train the vector index. If this is not done then the + results will be invalid. + + Parameters + ---------- + distance_type: "L2" or "cosine" or "dot" + The distance metric to use. By default "L2" is used. + + Returns + ------- + LanceVectorQueryBuilder + The LanceQueryBuilder object. + """ + self._distance_type = distance_type.lower() return self def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder: @@ -745,7 +773,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): filter=self._where, prefilter=self._prefilter, k=self._limit, - metric=self._metric, + metric=self._distance_type, columns=self._columns, nprobes=self._nprobes, lower_bound=self._lower_bound, @@ -1078,7 +1106,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._reranker = RRFReranker() self._nprobes = None self._refine_factor = None - self._metric = None + self._distance_type = None self._phrase_query = False def _validate_query(self, query, vector=None, text=None): @@ -1146,8 +1174,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._fts_query.with_row_id(True) if self._phrase_query: self._fts_query.phrase_query(True) - if self._metric: - self._vector_query.metric(self._metric) + if self._distance_type: + self._vector_query.metric(self._distance_type) if self._nprobes: self._vector_query.nprobes(self._nprobes) if self._refine_factor: @@ -1386,6 +1414,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder: """Set the distance metric to use. + This is an alias for distance_type() and may be deprecated in the future. + Please use distance_type() instead. + Parameters ---------- metric: "L2" or "cosine" or "dot" @@ -1396,7 +1427,32 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): LanceVectorQueryBuilder The LanceQueryBuilder object. """ - self._metric = metric.lower() + return self.distance_type(metric) + + def distance_type( + self, distance_type: Literal["L2", "cosine", "dot"] + ) -> "LanceHybridQueryBuilder": + """Set the distance metric to use. + + When performing a vector search we try and find the "nearest" vectors according + to some kind of distance metric. This parameter controls which distance metric + to use. + + Note: if there is a vector index then the distance type used MUST match the + distance type used to train the vector index. If this is not done then the + results will be invalid. + + Parameters + ---------- + distance_type: "L2" or "cosine" or "dot" + The distance metric to use. By default "L2" is used. + + Returns + ------- + LanceVectorQueryBuilder + The LanceQueryBuilder object. + """ + self._distance_type = distance_type.lower() return self def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder: diff --git a/python/python/tests/docs/test_binary_vector.py b/python/python/tests/docs/test_binary_vector.py index ec8da5be..fba4b85a 100644 --- a/python/python/tests/docs/test_binary_vector.py +++ b/python/python/tests/docs/test_binary_vector.py @@ -38,7 +38,7 @@ def test_binary_vector(): query = np.random.randint(0, 2, size=256) packed_query = np.packbits(query) - tbl.search(packed_query).metric("hamming").to_arrow() + tbl.search(packed_query).distance_type("hamming").to_arrow() # --8<-- [end:sync_binary_vector] db.drop_table("my_binary_vectors") diff --git a/python/python/tests/docs/test_search.py b/python/python/tests/docs/test_search.py index df1103ca..55be3529 100644 --- a/python/python/tests/docs/test_search.py +++ b/python/python/tests/docs/test_search.py @@ -65,7 +65,7 @@ def test_vector_search(): tbl.search(np.random.random((1536))).limit(10).to_list() # --8<-- [end:exhaustive_search] # --8<-- [start:exhaustive_search_cosine] - tbl.search(np.random.random((1536))).metric("cosine").limit(10).to_list() + tbl.search(np.random.random((1536))).distance_type("cosine").limit(10).to_list() # --8<-- [end:exhaustive_search_cosine] # --8<-- [start:create_table_with_nested_schema] # Let's add 100 sample rows to our dataset diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 3a2bac9a..5f37b515 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -377,14 +377,14 @@ def test_query_builder_with_metric(table): df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas() df_l2 = ( LanceVectorQueryBuilder(table, query, vector_column_name) - .metric("L2") + .distance_type("L2") .to_pandas() ) tm.assert_frame_equal(df_default, df_l2) df_cosine = ( LanceVectorQueryBuilder(table, query, vector_column_name) - .metric("cosine") + .distance_type("cosine") .limit(1) .to_pandas() ) @@ -401,7 +401,7 @@ def test_query_builder_with_different_vector_column(): vector_column_name = "foo_vector" builder = ( LanceVectorQueryBuilder(table, query, vector_column_name) - .metric("cosine") + .distance_type("cosine") .where("b < 10") .select(["b"]) .limit(2) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 0f94948c..64252ed3 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -366,7 +366,7 @@ def test_query_sync_maximal(): with query_test_table(handler) as table: ( table.search([1, 2, 3], vector_column_name="vector2", fast_search=True) - .metric("cosine") + .distance_type("cosine") .limit(42) .offset(10) .refine_factor(10) diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 82b29db0..a10ea597 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1242,7 +1242,9 @@ def test_hybrid_search_metric_type(tmp_db: DBConnection): # with custom metric result_dot = ( - table.search("feeling lucky", query_type="hybrid").metric("dot").to_arrow() + table.search("feeling lucky", query_type="hybrid") + .distance_type("dot") + .to_arrow() ) result_l2 = table.search("feeling lucky", query_type="hybrid").to_arrow() assert len(result_dot) > 0