From 55104c5bae87dbe3af6f1b4c2ada52c3beeb77bc Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 29 Oct 2024 13:51:18 -0700 Subject: [PATCH] feat: allow distance type (metric) to be specified during hybrid search (#1777) --- python/python/lancedb/conftest.py | 19 ++++++++++++++++- python/python/lancedb/query.py | 24 +++++++++++++++++++++- python/python/tests/test_table.py | 34 ++++++++++++++++++++++++++----- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/python/python/lancedb/conftest.py b/python/python/lancedb/conftest.py index 7a6a5fd1..a1c748f5 100644 --- a/python/python/lancedb/conftest.py +++ b/python/python/lancedb/conftest.py @@ -26,7 +26,7 @@ registry = EmbeddingFunctionRegistry.get_instance() @registry.register("test") class MockTextEmbeddingFunction(TextEmbeddingFunction): """ - Return the hash of the first 10 characters + Return the hash of the first 10 characters (normalized) """ def generate_embeddings(self, texts): @@ -41,6 +41,23 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction): return 10 +@registry.register("nonnorm") +class MockNonNormTextEmbeddingFunction(TextEmbeddingFunction): + """ + Return the ord of the first 10 characters (not normalized) + """ + + def generate_embeddings(self, texts): + return [self._compute_one_embedding(row) for row in texts] + + def _compute_one_embedding(self, row): + emb = np.array([float(ord(c)) for c in row[:10]]) + return emb if len(emb) == 10 else [0] * 10 + + def ndims(self): + return 10 + + class RateLimitedAPI: rate_limit = 0.1 # 1 request per 0.1 second last_request_time = 0 diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index c79b8846..1062289e 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -983,6 +983,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._reranker = RRFReranker() self._nprobes = None self._refine_factor = None + self._metric = None self._phrase_query = False def _validate_query(self, query, 
vector=None, text=None): @@ -1050,6 +1051,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._fts_query.with_row_id(True) if self._phrase_query: self._fts_query.phrase_query(True) + if self._metric: + self._vector_query.metric(self._metric) if self._nprobes: self._vector_query.nprobes(self._nprobes) if self._refine_factor: @@ -1067,6 +1070,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): if self._norm == "rank": vector_results = self._rank(vector_results, "_distance") fts_results = self._rank(fts_results, "_score") + # normalize the scores to be between 0 and 1, 0 being most relevant vector_results = self._normalize_scores(vector_results, "_distance") @@ -1115,7 +1119,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): rng = max else: rng = max - min - scores = (scores - min) / rng + # If rng is 0 then min and max are both 0 and so we can leave the scores as is + if rng != 0: + scores = (scores - min) / rng if invert: scores = 1 - scores # replace the _score column with the ranks @@ -1177,6 +1183,22 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._nprobes = nprobes return self + def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder: + """Set the distance metric to use. + + Parameters + ---------- + metric: "L2" or "cosine" or "dot" + The distance metric to use. By default "L2" is used. + + Returns + ------- + LanceHybridQueryBuilder + The LanceQueryBuilder object. 
+ """ + self._metric = metric.lower() + return self + def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder: """ Refine the vector search results by reading extra elements and diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 65ec7b3c..bdf22ddf 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -991,13 +991,10 @@ def test_count_rows(db): assert table.count_rows(filter="text='bar'") == 1 -def test_hybrid_search(db, tmp_path): - # This test uses an FTS index - pytest.importorskip("lancedb.fts") - +def setup_hybrid_search_table(tmp_path, embedding_func): db = MockDB(str(tmp_path)) # Create a LanceDB table schema with a vector and a text column - emb = EmbeddingFunctionRegistry.get_instance().get("test")() + emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)() class MyTable(LanceModel): text: str = emb.SourceField() @@ -1030,6 +1027,15 @@ def test_hybrid_search(db, tmp_path): # Create a fts index table.create_fts_index("text") + return table, MyTable, emb + + +def test_hybrid_search(tmp_path): + # This test uses an FTS index + pytest.importorskip("lancedb.fts") + + table, MyTable, emb = setup_hybrid_search_table(tmp_path, "test") + result1 = ( table.search("Our father who art in heaven", query_type="hybrid") .rerank(normalize="score") @@ -1094,6 +1100,24 @@ def test_hybrid_search(db, tmp_path): table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list() +def test_hybrid_search_metric_type(db, tmp_path): + # This test uses an FTS index + pytest.importorskip("lancedb.fts") + + # Need to use nonnorm as the embedding function so L2 and dot results + # are different + table, _, _ = setup_hybrid_search_table(tmp_path, "nonnorm") + + # with custom metric + result_dot = ( + table.search("feeling lucky", query_type="hybrid").metric("dot").to_arrow() + ) + result_l2 = table.search("feeling lucky", query_type="hybrid").to_arrow() + assert len(result_dot) > 0 + 
assert len(result_l2) > 0 + assert result_dot["_relevance_score"] != result_l2["_relevance_score"] + + @pytest.mark.parametrize( "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)] )