mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 04:12:59 +00:00
fix: add better check for empty results in hybrid search (#2252)
fixes: https://github.com/lancedb/lancedb/issues/2249
This commit is contained in:
@@ -1220,6 +1220,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
vector_results = LanceHybridQueryBuilder._rank(vector_results, "_distance")
|
||||
fts_results = LanceHybridQueryBuilder._rank(fts_results, "_score")
|
||||
|
||||
original_distances = None
|
||||
original_scores = None
|
||||
original_distance_row_ids = None
|
||||
original_score_row_ids = None
|
||||
# normalize the scores to be between 0 and 1, 0 being most relevant
|
||||
# We check whether the results (vector and FTS) are empty, because when
|
||||
# they are, they are often missing the _rowid column, which causes an error
|
||||
@@ -1249,7 +1253,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
check_reranker_result(results)
|
||||
|
||||
if "_distance" in results.column_names:
|
||||
if "_distance" in results.column_names and original_distances is not None:
|
||||
# restore the original distances
|
||||
indices = pc.index_in(
|
||||
results["_rowid"], original_distance_row_ids, skip_nulls=True
|
||||
@@ -1258,7 +1262,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
distance_i = results.column_names.index("_distance")
|
||||
results = results.set_column(distance_i, "_distance", original_distances)
|
||||
|
||||
if "_score" in results.column_names:
|
||||
if "_score" in results.column_names and original_scores is not None:
|
||||
# restore the original scores
|
||||
indices = pc.index_in(
|
||||
results["_rowid"], original_score_row_ids, skip_nulls=True
|
||||
|
||||
@@ -368,6 +368,26 @@ def test_rrf_reranker_distance():
|
||||
assert score == fts_scores[rowid], "Score mismatch"
|
||||
assert found_match, "No results matched between hybrid and fts search"
|
||||
|
||||
# Test for empty fts results
|
||||
fts_results = (
|
||||
table.search("abcxyz" * 100, query_type="fts").with_row_id(True).to_list()
|
||||
)
|
||||
hybrid_results = (
|
||||
table.search(query_type="hybrid")
|
||||
.vector([0.0] * 32)
|
||||
.text("abcxyz" * 100)
|
||||
.with_row_id(True)
|
||||
.rerank(reranker)
|
||||
.to_list()
|
||||
)
|
||||
assert len(fts_results) == 0
|
||||
# confirm that _rowid, _score, _distance and _relevance_score are present in the hybrid results
|
||||
assert len(hybrid_results) > 0
|
||||
assert "_rowid" in hybrid_results[0]
|
||||
assert "_score" in hybrid_results[0]
|
||||
assert "_distance" in hybrid_results[0]
|
||||
assert "_relevance_score" in hybrid_results[0]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
|
||||
Reference in New Issue
Block a user