diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 413d26a3..348da579 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1220,6 +1220,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): vector_results = LanceHybridQueryBuilder._rank(vector_results, "_distance") fts_results = LanceHybridQueryBuilder._rank(fts_results, "_score") + original_distances = None + original_scores = None + original_distance_row_ids = None + original_score_row_ids = None # normalize the scores to be between 0 and 1, 0 being most relevant # We check whether the results (vector and FTS) are empty, because when # they are, they often are missing the _rowid column, which causes an error @@ -1249,7 +1253,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): check_reranker_result(results) - if "_distance" in results.column_names: + if "_distance" in results.column_names and original_distances is not None: # restore the original distances indices = pc.index_in( results["_rowid"], original_distance_row_ids, skip_nulls=True @@ -1258,7 +1262,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): distance_i = results.column_names.index("_distance") results = results.set_column(distance_i, "_distance", original_distances) - if "_score" in results.column_names: + if "_score" in results.column_names and original_scores is not None: # restore the original scores indices = pc.index_in( results["_rowid"], original_score_row_ids, skip_nulls=True diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index a4cd1290..4c50c831 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -368,6 +368,26 @@ def test_rrf_reranker_distance(): assert score == fts_scores[rowid], "Score mismatch" assert found_match, "No results matched between hybrid and fts search" + # Test for empty fts results + fts_results = ( + table.search("abcxyz" * 100, query_type="fts").with_row_id(True).to_list() + ) + hybrid_results = ( + table.search(query_type="hybrid") + .vector([0.0] * 32) + .text("abcxyz" * 100) + .with_row_id(True) + .rerank(reranker) + .to_list() + ) + assert len(fts_results) == 0 + # confirm if _rowid, _score, _distance & _relevance_score are present in hybrid + assert len(hybrid_results) > 0 + assert "_rowid" in hybrid_results[0] + assert "_score" in hybrid_results[0] + assert "_distance" in hybrid_results[0] + assert "_relevance_score" in hybrid_results[0] + @pytest.mark.skipif( os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"