diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 23ab8c18..56f988e8 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1374,6 +1374,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): if query_string is not None and not isinstance(query_string, str): raise ValueError("Reranking currently only supports string queries") self._str_query = query_string if query_string is not None else self._str_query + if reranker.score == "all": + self.with_row_id(True) return self def bypass_vector_index(self) -> LanceVectorQueryBuilder: @@ -1569,6 +1571,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): The LanceQueryBuilder object. """ self._reranker = reranker + if reranker.score == "all": + self.with_row_id(True) return self @@ -1845,6 +1849,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._norm = normalize self._reranker = reranker + if reranker.score == "all": + self.with_row_id(True) return self diff --git a/python/python/lancedb/rerankers/answerdotai.py b/python/python/lancedb/rerankers/answerdotai.py index 642f51aa..d615ed31 100644 --- a/python/python/lancedb/rerankers/answerdotai.py +++ b/python/python/lancedb/rerankers/answerdotai.py @@ -74,9 +74,7 @@ class AnswerdotaiRerankers(Reranker): if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) elif self.score == "all": - raise NotImplementedError( - "Answerdotai Reranker does not support score='all' yet" - ) + combined_results = self._merge_and_keep_scores(vector_results, fts_results) combined_results = combined_results.sort_by( [("_relevance_score", "descending")] ) diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index bc33eee9..0c546a77 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -232,6 +232,39 @@ class Reranker(ABC): return deduped_table + def _merge_and_keep_scores(self, vector_results: pa.Table, fts_results: pa.Table): + """ + Merge the results from the vector and FTS search and keep the scores. + This op is slower than just keeping relevance score but can be useful + for debugging. + """ + # add nulls to fts results for _distance + if "_distance" not in fts_results.column_names: + fts_results = fts_results.append_column( + "_distance", + pa.array([None] * len(fts_results), type=pa.float32()), + ) + # add nulls to vector results for _score + if "_score" not in vector_results.column_names: + vector_results = vector_results.append_column( + "_score", + pa.array([None] * len(vector_results), type=pa.float32()), + ) + + # combine them and fill the scores + vector_results_dict = {row["_rowid"]: row for row in vector_results.to_pylist()} + fts_results_dict = {row["_rowid"]: row for row in fts_results.to_pylist()} + + # merge them into vector_results + for key, value in fts_results_dict.items(): + if key in vector_results_dict: + vector_results_dict[key]["_score"] = value["_score"] + else: + vector_results_dict[key] = value + + combined = pa.Table.from_pylist(list(vector_results_dict.values())) + return combined + def _keep_relevance_score(self, combined_results: pa.Table): if self.score == "relevance": if "_score" in combined_results.column_names: diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index 0b499213..10796067 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -92,14 +92,14 @@ class CohereReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for cohere reranker" - ) + return combined_results def rerank_vector(self, query: str, vector_results: pa.Table): diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index 4ca2b93d..fd24eed1 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -81,15 +81,15 @@ class CrossEncoderReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) # sort the results by _score if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for CrossEncoderReranker" - ) + combined_results = combined_results.sort_by( [("_relevance_score", "descending")] ) diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index c398cda8..84ab21f4 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -97,14 +97,14 @@ class JinaReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for JinaReranker" - ) + return combined_results def rerank_vector(self, query: str, vector_results: pa.Table): diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index 6ee2da2e..7b181e80 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -88,14 +88,13 @@ class OpenaiReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "OpenAI Reranker does not support score='all' yet" - ) combined_results = combined_results.sort_by( [("_relevance_score", "descending")] diff --git a/python/python/lancedb/rerankers/voyageai.py b/python/python/lancedb/rerankers/voyageai.py index 9ef8818a..f99c0ae4 100644 --- a/python/python/lancedb/rerankers/voyageai.py +++ b/python/python/lancedb/rerankers/voyageai.py @@ -94,14 +94,14 @@ class VoyageAIReranker(Reranker): vector_results: pa.Table, fts_results: pa.Table, ): - combined_results = self.merge_results(vector_results, fts_results) + if self.score == "all": + combined_results = self._merge_and_keep_scores(vector_results, fts_results) + else: + combined_results = self.merge_results(vector_results, fts_results) combined_results = self._rerank(combined_results, query) if self.score == "relevance": combined_results = self._keep_relevance_score(combined_results) - elif self.score == "all": - raise NotImplementedError( - "return_score='all' not implemented for voyageai reranker" - ) + return combined_results def rerank_vector(self, query: str, vector_results: pa.Table): diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index d82aae2a..c40afac8 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -499,3 +499,19 @@ def test_empty_result_reranker(): .rerank(reranker) .to_arrow() ) + + +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_cross_encoder_reranker_return_all(tmp_path, use_tantivy): + pytest.importorskip("sentence_transformers") + reranker = CrossEncoderReranker(return_score="all") + table, schema = get_test_table(tmp_path, use_tantivy) + query = "single player experience" + result = ( + table.search(query, query_type="hybrid", vector_column_name="vector") + .rerank(reranker=reranker) + .to_arrow() + ) + assert "_relevance_score" in result.column_names + assert "_score" in result.column_names + assert "_distance" in result.column_names