mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
feat: add support for returning all scores with rerankers (#2509)
Previously `return_score="all"` was supported only for the default reranker (RRF) and not the model based rerankers. This adds support for keeping all scores in the base reranker so that all model based rerankers can use it. Its a slower path than keeping just the relevance score but can be useful in debugging
This commit is contained in:
@@ -1374,6 +1374,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
if query_string is not None and not isinstance(query_string, str):
|
||||
raise ValueError("Reranking currently only supports string queries")
|
||||
self._str_query = query_string if query_string is not None else self._str_query
|
||||
if reranker.score == "all":
|
||||
self.with_row_id(True)
|
||||
return self
|
||||
|
||||
def bypass_vector_index(self) -> LanceVectorQueryBuilder:
|
||||
@@ -1569,6 +1571,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._reranker = reranker
|
||||
if reranker.score == "all":
|
||||
self.with_row_id(True)
|
||||
return self
|
||||
|
||||
|
||||
@@ -1845,6 +1849,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
self._norm = normalize
|
||||
self._reranker = reranker
|
||||
if reranker.score == "all":
|
||||
self.with_row_id(True)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@@ -74,9 +74,7 @@ class AnswerdotaiRerankers(Reranker):
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"Answerdotai Reranker does not support score='all' yet"
|
||||
)
|
||||
combined_results = self._merge_and_keep_scores(vector_results, fts_results)
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
@@ -232,6 +232,39 @@ class Reranker(ABC):
|
||||
|
||||
return deduped_table
|
||||
|
||||
def _merge_and_keep_scores(self, vector_results: pa.Table, fts_results: pa.Table):
|
||||
"""
|
||||
Merge the results from the vector and FTS search and keep the scores.
|
||||
This op is slower than just keeping relevance score but can be useful
|
||||
for debugging.
|
||||
"""
|
||||
# add nulls to fts results for _distance
|
||||
if "_distance" not in fts_results.column_names:
|
||||
fts_results = fts_results.append_column(
|
||||
"_distance",
|
||||
pa.array([None] * len(fts_results), type=pa.float32()),
|
||||
)
|
||||
# add nulls to vector results for _score
|
||||
if "_score" not in vector_results.column_names:
|
||||
vector_results = vector_results.append_column(
|
||||
"_score",
|
||||
pa.array([None] * len(vector_results), type=pa.float32()),
|
||||
)
|
||||
|
||||
# combine them and fill the scores
|
||||
vector_results_dict = {row["_rowid"]: row for row in vector_results.to_pylist()}
|
||||
fts_results_dict = {row["_rowid"]: row for row in fts_results.to_pylist()}
|
||||
|
||||
# merge them into vector_results
|
||||
for key, value in fts_results_dict.items():
|
||||
if key in vector_results_dict:
|
||||
vector_results_dict[key]["_score"] = value["_score"]
|
||||
else:
|
||||
vector_results_dict[key] = value
|
||||
|
||||
combined = pa.Table.from_pylist(list(vector_results_dict.values()))
|
||||
return combined
|
||||
|
||||
def _keep_relevance_score(self, combined_results: pa.Table):
|
||||
if self.score == "relevance":
|
||||
if "_score" in combined_results.column_names:
|
||||
|
||||
@@ -92,14 +92,14 @@ class CohereReranker(Reranker):
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
if self.score == "all":
|
||||
combined_results = self._merge_and_keep_scores(vector_results, fts_results)
|
||||
else:
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for cohere reranker"
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
|
||||
@@ -81,15 +81,15 @@ class CrossEncoderReranker(Reranker):
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
if self.score == "all":
|
||||
combined_results = self._merge_and_keep_scores(vector_results, fts_results)
|
||||
else:
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
# sort the results by _score
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for CrossEncoderReranker"
|
||||
)
|
||||
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
@@ -97,14 +97,14 @@ class JinaReranker(Reranker):
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
if self.score == "all":
|
||||
combined_results = self._merge_and_keep_scores(vector_results, fts_results)
|
||||
else:
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for JinaReranker"
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
|
||||
@@ -88,14 +88,13 @@ class OpenaiReranker(Reranker):
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
if self.score == "all":
|
||||
combined_results = self._merge_and_keep_scores(vector_results, fts_results)
|
||||
else:
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
)
|
||||
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
|
||||
@@ -94,14 +94,14 @@ class VoyageAIReranker(Reranker):
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
if self.score == "all":
|
||||
combined_results = self._merge_and_keep_scores(vector_results, fts_results)
|
||||
else:
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for voyageai reranker"
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(self, query: str, vector_results: pa.Table):
|
||||
|
||||
@@ -499,3 +499,19 @@ def test_empty_result_reranker():
|
||||
.rerank(reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_cross_encoder_reranker_return_all(tmp_path, use_tantivy):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
reranker = CrossEncoderReranker(return_score="all")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
query = "single player experience"
|
||||
result = (
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector")
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert "_score" in result.column_names
|
||||
assert "_distance" in result.column_names
|
||||
|
||||
Reference in New Issue
Block a user