From f076bb41f4c94273260044a97887b8b66596c06e Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 15 Jul 2025 21:03:03 +0530 Subject: [PATCH] feat: add support for returning all scores with rerankers (#2509) Previously `return_score="all"` was supported only for the default reranker (RRF) and not the model based rerankers. This adds support for keeping all scores in the base reranker so that all model based rerankers can use it. Its a slower path than keeping just the relevance score but can be useful in debugging --- python/python/lancedb/query.py | 6 ++++ .../python/lancedb/rerankers/answerdotai.py | 4 +-- python/python/lancedb/rerankers/base.py | 33 +++++++++++++++++++ python/python/lancedb/rerankers/cohere.py | 10 +++--- .../python/lancedb/rerankers/cross_encoder.py | 10 +++--- python/python/lancedb/rerankers/jinaai.py | 10 +++--- python/python/lancedb/rerankers/openai.py | 9 +++-- python/python/lancedb/rerankers/voyageai.py | 10 +++--- python/python/tests/test_rerankers.py | 16 +++++++++ 9 files changed, 80 insertions(+), 28 deletions(-) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 23ab8c18..56f988e8 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1374,6 +1374,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): if query_string is not None and not isinstance(query_string, str): raise ValueError("Reranking currently only supports string queries") self._str_query = query_string if query_string is not None else self._str_query + if reranker.score == "all": + self.with_row_id(True) return self def bypass_vector_index(self) -> LanceVectorQueryBuilder: @@ -1569,6 +1571,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): The LanceQueryBuilder object. """ self._reranker = reranker + if reranker.score == "all": + self.with_row_id(True) return self @@ -1845,6 +1849,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._norm = normalize self._reranker = reranker + if reranker.score == "all": + self.with_row_id(True) return self diff --git a/python/python/lancedb/rerankers/answerdotai.py b/python/python/lancedb/rerankers/answerdotai.py index 642f51aa..d615ed31 100644 --- a/python/python/lancedb/rerankers/answerdotai.py +++ b/python/python/lancedb/rerankers/answerdotai.py @@ -74,9 +74,7 @@ class AnswerdotaiRerankers(Reranker): if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) elif self.score == "all": - raise NotImplementedError( - "Answerdotai Reranker does not support score='all' yet" - ) + combined_results = self._merge_and_keep_scores(vector_results, fts_results) combined_results = combined_results.sort_by( [("_relevance_score", "descending")] ) diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index bc33eee9..0c546a77 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -232,6 +232,39 @@ class Reranker(ABC): return deduped_table + def _merge_and_keep_scores(self, vector_results: pa.Table, fts_results: pa.Table): + """ + Merge the results from the vector and FTS search and keep the scores. + This op is slower than just keeping relevance score but can be useful + for debugging. + """ + # add nulls to fts results for _distance + if "_distance" not in fts_results.column_names: + fts_results = fts_results.append_column( + "_distance", + pa.array([None] * len(fts_results), type=pa.float32()), + ) + # add nulls to vector results for _score + if "_score" not in vector_results.column_names: + vector_results = vector_results.append_column( + "_score", + pa.array([None] * len(vector_results), type=pa.float32()), + ) + + # combine them and fill the scores + vector_results_dict = {row["_rowid"]: row for row in vector_results.to_pylist()} + fts_results_dict = {row["_rowid"]: row for row in fts_results.to_pylist()} + + # merge them into vector_results + for key, value in fts_results_dict.items(): + if key in vector_results_dict: + vector_results_dict[key]["_score"] = value["_score"] + else: + vector_results_dict[key] = value + + combined = pa.Table.from_pylist(list(vector_results_dict.values())) + return combined + def _keep_relevance_score(self, combined_results: pa.Table): if self.score == "relevance": if "_score" in combined_results.column_names: diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index 0b499213..10796067 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -92,14 +92,14 @@ class CohereReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for cohere reranker" - ) + return combined_results def rerank_vector(self, query: str, vector_results: pa.Table): diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index 4ca2b93d..fd24eed1 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -81,15 +81,15 @@ class CrossEncoderReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) # sort the results by _score if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for CrossEncoderReranker" - ) + combined_results = combined_results.sort_by( [("_relevance_score", "descending")] ) diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index c398cda8..84ab21f4 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -97,14 +97,14 @@ class JinaReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for JinaReranker" - ) + return combined_results def rerank_vector(self, query: str, vector_results: pa.Table): diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index 6ee2da2e..7b181e80 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -88,14 +88,13 @@ class OpenaiReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "OpenAI Reranker does not support score='all' yet" - ) combined_results = combined_results.sort_by( [("_relevance_score", "descending")] diff --git a/python/python/lancedb/rerankers/voyageai.py b/python/python/lancedb/rerankers/voyageai.py index 9ef8818a..f99c0ae4 100644 --- a/python/python/lancedb/rerankers/voyageai.py +++ b/python/python/lancedb/rerankers/voyageai.py @@ -94,14 +94,14 @@ class VoyageAIReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for voyageai reranker" - ) + return combined_results def rerank_vector(self, query: str, vector_results: pa.Table): diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index d82aae2a..c40afac8 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -499,3 +499,19 @@ def test_empty_result_reranker(): .rerank(reranker) .to_arrow() ) + + +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_cross_encoder_reranker_return_all(tmp_path, use_tantivy): + pytest.importorskip("sentence_transformers") + reranker = CrossEncoderReranker(return_score="all") + table, schema = get_test_table(tmp_path, use_tantivy) + query = "single player experience" + result = ( + table.search(query, query_type="hybrid", vector_column_name="vector") + .rerank(reranker=reranker) + .to_arrow() + ) + assert "_relevance_score" in result.column_names + assert "_score" in result.column_names + assert "_distance" in result.column_names