fix: add better check for empty results in hybrid search (#2252)

fixes: https://github.com/lancedb/lancedb/issues/2249
This commit is contained in:
Ayush Chaurasia
2025-03-21 13:05:05 +05:30
committed by GitHub
parent b595d8a579
commit ba1ded933a
2 changed files with 26 additions and 2 deletions

View File

@@ -1220,6 +1220,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
vector_results = LanceHybridQueryBuilder._rank(vector_results, "_distance")
fts_results = LanceHybridQueryBuilder._rank(fts_results, "_score")
original_distances = None
original_scores = None
original_distance_row_ids = None
original_score_row_ids = None
# normalize the scores to be between 0 and 1, 0 being most relevant
# We check whether the results (vector and FTS) are empty, because when
# they are, they often are missing the _rowid column, which causes an error
@@ -1249,7 +1253,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
check_reranker_result(results)
if "_distance" in results.column_names:
if "_distance" in results.column_names and original_distances is not None:
# restore the original distances
indices = pc.index_in(
results["_rowid"], original_distance_row_ids, skip_nulls=True
@@ -1258,7 +1262,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
distance_i = results.column_names.index("_distance")
results = results.set_column(distance_i, "_distance", original_distances)
if "_score" in results.column_names:
if "_score" in results.column_names and original_scores is not None:
# restore the original scores
indices = pc.index_in(
results["_rowid"], original_score_row_ids, skip_nulls=True

View File

@@ -368,6 +368,26 @@ def test_rrf_reranker_distance():
assert score == fts_scores[rowid], "Score mismatch"
assert found_match, "No results matched between hybrid and fts search"
# Test for empty fts results
fts_results = (
table.search("abcxyz" * 100, query_type="fts").with_row_id(True).to_list()
)
hybrid_results = (
table.search(query_type="hybrid")
.vector([0.0] * 32)
.text("abcxyz" * 100)
.with_row_id(True)
.rerank(reranker)
.to_list()
)
assert len(fts_results) == 0
# confirm if _rowid, _score, _distance & _relevance_score are present in hybrid
assert len(hybrid_results) > 0
assert "_rowid" in hybrid_results[0]
assert "_score" in hybrid_results[0]
assert "_distance" in hybrid_results[0]
assert "_relevance_score" in hybrid_results[0]
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"