mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
fix: robust handling of empty result when reranking (#2313)
I found some edge cases while running experiments that - depending on the base reranking libraries, some of them don't handle empty lists well. This PR manually checks if the result set to be reranked is empty <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Enhanced search result processing by ensuring that reordering only occurs when valid, non-empty results are available, thereby preventing unnecessary operations and potential errors. - **Tests** - Added automated tests to verify that empty search result sets are handled correctly, ensuring consistent behavior across various rerankers. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -47,6 +47,9 @@ class AnswerdotaiRerankers(Reranker):
|
||||
)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
return result_set
|
||||
docs = result_set[self.column].to_pylist()
|
||||
doc_ids = list(range(len(docs)))
|
||||
result = self.reranker.rank(query, docs, doc_ids=doc_ids)
|
||||
@@ -83,7 +86,6 @@ class AnswerdotaiRerankers(Reranker):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
return vector_results
|
||||
|
||||
@@ -91,7 +93,5 @@ class AnswerdotaiRerankers(Reranker):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["_score"])
|
||||
|
||||
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return fts_results
|
||||
|
||||
@@ -65,6 +65,16 @@ class Reranker(ABC):
|
||||
f"{self.__class__.__name__} does not implement rerank_vector"
|
||||
)
|
||||
|
||||
def _handle_empty_results(self, results: pa.Table):
|
||||
"""
|
||||
Helper method to handle empty FTS results consistently
|
||||
"""
|
||||
if len(results) > 0:
|
||||
return results
|
||||
return results.append_column(
|
||||
"_relevance_score", pa.array([], type=pa.float32())
|
||||
)
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@@ -62,6 +62,9 @@ class CohereReranker(Reranker):
|
||||
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
return result_set
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.rerank(
|
||||
query=query,
|
||||
@@ -99,24 +102,14 @@ class CohereReranker(Reranker):
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
return vector_results
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
def rerank_fts(self, query: str, fts_results: pa.Table):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_score"])
|
||||
|
||||
return result_set
|
||||
fts_results = fts_results.drop_columns(["_score"])
|
||||
return fts_results
|
||||
|
||||
@@ -63,6 +63,9 @@ class CrossEncoderReranker(Reranker):
|
||||
return cross_encoder
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
return result_set
|
||||
passages = result_set[self.column].to_pylist()
|
||||
cross_inp = [[query, passage] for passage in passages]
|
||||
cross_scores = self.model.predict(cross_inp)
|
||||
@@ -93,11 +96,7 @@ class CrossEncoderReranker(Reranker):
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
@@ -105,11 +104,7 @@ class CrossEncoderReranker(Reranker):
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
return vector_results
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
def rerank_fts(self, query: str, fts_results: pa.Table):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["_score"])
|
||||
|
||||
@@ -62,6 +62,9 @@ class JinaReranker(Reranker):
|
||||
return self._session
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
return result_set
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.post( # type: ignore
|
||||
API_URL,
|
||||
@@ -104,24 +107,14 @@ class JinaReranker(Reranker):
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
return vector_results
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
def rerank_fts(self, query: str, fts_results: pa.Table):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_score"])
|
||||
|
||||
return result_set
|
||||
fts_results = fts_results.drop_columns(["_score"])
|
||||
return fts_results
|
||||
|
||||
@@ -44,6 +44,9 @@ class OpenaiReranker(Reranker):
|
||||
self.api_key = api_key
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
return result_set
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
@@ -104,18 +107,14 @@ class OpenaiReranker(Reranker):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return vector_results
|
||||
|
||||
def rerank_fts(self, query: str, fts_results: pa.Table):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["_score"])
|
||||
|
||||
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
return fts_results
|
||||
|
||||
@cached_property
|
||||
|
||||
@@ -63,6 +63,9 @@ class VoyageAIReranker(Reranker):
|
||||
)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
result_set = self._handle_empty_results(result_set)
|
||||
if len(result_set) == 0:
|
||||
return result_set
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.rerank(
|
||||
query=query,
|
||||
@@ -101,24 +104,14 @@ class VoyageAIReranker(Reranker):
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
return vector_results
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
def rerank_fts(self, query: str, fts_results: pa.Table):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_score"])
|
||||
|
||||
return result_set
|
||||
fts_results = fts_results.drop_columns(["_score"])
|
||||
return fts_results
|
||||
|
||||
@@ -457,3 +457,45 @@ def test_voyageai_reranker(tmp_path, use_tantivy):
|
||||
reranker = VoyageAIReranker(model_name="rerank-2")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
def test_empty_result_reranker():
|
||||
pytest.importorskip("sentence_transformers")
|
||||
db = lancedb.connect("memory://")
|
||||
|
||||
# Define schema
|
||||
schema = pa.schema(
|
||||
[
|
||||
("id", pa.int64()),
|
||||
("text", pa.string()),
|
||||
("vector", pa.list_(pa.float32(), 128)), # 128-dimensional vector
|
||||
]
|
||||
)
|
||||
|
||||
# Create empty table with schema
|
||||
empty_table = db.create_table("empty_table", schema=schema, mode="overwrite")
|
||||
empty_table.create_fts_index("text", use_tantivy=False, replace=True)
|
||||
for reranker in [
|
||||
CrossEncoderReranker(),
|
||||
# ColbertReranker(),
|
||||
# AnswerdotaiRerankers(),
|
||||
# OpenaiReranker(),
|
||||
# JinaReranker(),
|
||||
# VoyageAIReranker(model_name="rerank-2"),
|
||||
]:
|
||||
results = (
|
||||
empty_table.search(list(range(128)))
|
||||
.limit(3)
|
||||
.rerank(reranker, "query")
|
||||
.to_arrow()
|
||||
)
|
||||
# check if empty set contains _relevance_score column
|
||||
assert "_relevance_score" in results.column_names
|
||||
assert len(results) == 0
|
||||
|
||||
results = (
|
||||
empty_table.search("query", query_type="fts")
|
||||
.limit(3)
|
||||
.rerank(reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user