diff --git a/python/python/lancedb/rerankers/answerdotai.py b/python/python/lancedb/rerankers/answerdotai.py index 63df48c9..642f51aa 100644 --- a/python/python/lancedb/rerankers/answerdotai.py +++ b/python/python/lancedb/rerankers/answerdotai.py @@ -47,6 +47,9 @@ class AnswerdotaiRerankers(Reranker): ) def _rerank(self, result_set: pa.Table, query: str): + result_set = self._handle_empty_results(result_set) + if len(result_set) == 0: + return result_set docs = result_set[self.column].to_pylist() doc_ids = list(range(len(docs))) result = self.reranker.rank(query, docs, doc_ids=doc_ids) @@ -83,7 +86,6 @@ class AnswerdotaiRerankers(Reranker): vector_results = self._rerank(vector_results, query) if self.score == "relevance": vector_results = vector_results.drop_columns(["_distance"]) - vector_results = vector_results.sort_by([("_relevance_score", "descending")]) return vector_results @@ -91,7 +93,5 @@ class AnswerdotaiRerankers(Reranker): fts_results = self._rerank(fts_results, query) if self.score == "relevance": fts_results = fts_results.drop_columns(["_score"]) - fts_results = fts_results.sort_by([("_relevance_score", "descending")]) - return fts_results diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index ece08eec..bc33eee9 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -65,6 +65,16 @@ class Reranker(ABC): f"{self.__class__.__name__} does not implement rerank_vector" ) + def _handle_empty_results(self, results: pa.Table): + """ + Helper method to handle empty FTS results consistently + """ + if len(results) > 0: + return results + return results.append_column( + "_relevance_score", pa.array([], type=pa.float32()) + ) + def rerank_fts( self, query: str, diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index b4044995..0b499213 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -62,6 +62,9 @@ class CohereReranker(Reranker): return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key) def _rerank(self, result_set: pa.Table, query: str): + result_set = self._handle_empty_results(result_set) + if len(result_set) == 0: + return result_set docs = result_set[self.column].to_pylist() response = self._client.rerank( query=query, @@ -99,24 +102,14 @@ class CohereReranker(Reranker): ) return combined_results - def rerank_vector( - self, - query: str, - vector_results: pa.Table, - ): - result_set = self._rerank(vector_results, query) + def rerank_vector(self, query: str, vector_results: pa.Table): + vector_results = self._rerank(vector_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["_distance"]) + vector_results = vector_results.drop_columns(["_distance"]) + return vector_results - return result_set - - def rerank_fts( - self, - query: str, - fts_results: pa.Table, - ): - result_set = self._rerank(fts_results, query) + def rerank_fts(self, query: str, fts_results: pa.Table): + fts_results = self._rerank(fts_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["_score"]) - - return result_set + fts_results = fts_results.drop_columns(["_score"]) + return fts_results diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index 3920f4d0..4ca2b93d 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -63,6 +63,9 @@ class CrossEncoderReranker(Reranker): return cross_encoder def _rerank(self, result_set: pa.Table, query: str): + result_set = self._handle_empty_results(result_set) + if len(result_set) == 0: + return result_set passages = result_set[self.column].to_pylist() cross_inp = [[query, passage] for passage in passages] cross_scores = self.model.predict(cross_inp) @@ -93,11 +96,7 @@ class CrossEncoderReranker(Reranker): return combined_results - def rerank_vector( - self, - query: str, - vector_results: pa.Table, - ): + def rerank_vector(self, query: str, vector_results: pa.Table): vector_results = self._rerank(vector_results, query) if self.score == "relevance": vector_results = vector_results.drop_columns(["_distance"]) @@ -105,11 +104,7 @@ class CrossEncoderReranker(Reranker): vector_results = vector_results.sort_by([("_relevance_score", "descending")]) return vector_results - def rerank_fts( - self, - query: str, - fts_results: pa.Table, - ): + def rerank_fts(self, query: str, fts_results: pa.Table): fts_results = self._rerank(fts_results, query) if self.score == "relevance": fts_results = fts_results.drop_columns(["_score"]) diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index 781455eb..c398cda8 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -62,6 +62,9 @@ class JinaReranker(Reranker): return self._session def _rerank(self, result_set: pa.Table, query: str): + result_set = self._handle_empty_results(result_set) + if len(result_set) == 0: + return result_set docs = result_set[self.column].to_pylist() response = self._client.post( # type: ignore API_URL, @@ -104,24 +107,14 @@ class JinaReranker(Reranker): ) return combined_results - def rerank_vector( - self, - query: str, - vector_results: pa.Table, - ): - result_set = self._rerank(vector_results, query) + def rerank_vector(self, query: str, vector_results: pa.Table): + vector_results = self._rerank(vector_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["_distance"]) + vector_results = vector_results.drop_columns(["_distance"]) + return vector_results - return result_set - - def rerank_fts( - self, - query: str, - fts_results: pa.Table, - ): - result_set = self._rerank(fts_results, query) + def rerank_fts(self, query: str, fts_results: pa.Table): + fts_results = self._rerank(fts_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["_score"]) - - return result_set + fts_results = fts_results.drop_columns(["_score"]) + return fts_results diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index 7bffdc56..6ee2da2e 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -44,6 +44,9 @@ class OpenaiReranker(Reranker): self.api_key = api_key def _rerank(self, result_set: pa.Table, query: str): + result_set = self._handle_empty_results(result_set) + if len(result_set) == 0: + return result_set docs = result_set[self.column].to_pylist() response = self._client.chat.completions.create( model=self.model_name, @@ -104,18 +107,14 @@ class OpenaiReranker(Reranker): vector_results = self._rerank(vector_results, query) if self.score == "relevance": vector_results = vector_results.drop_columns(["_distance"]) - vector_results = vector_results.sort_by([("_relevance_score", "descending")]) - return vector_results def rerank_fts(self, query: str, fts_results: pa.Table): fts_results = self._rerank(fts_results, query) if self.score == "relevance": fts_results = fts_results.drop_columns(["_score"]) - fts_results = fts_results.sort_by([("_relevance_score", "descending")]) - return fts_results @cached_property diff --git a/python/python/lancedb/rerankers/voyageai.py b/python/python/lancedb/rerankers/voyageai.py index c0e2d23c..9ef8818a 100644 --- a/python/python/lancedb/rerankers/voyageai.py +++ b/python/python/lancedb/rerankers/voyageai.py @@ -63,6 +63,9 @@ class VoyageAIReranker(Reranker): ) def _rerank(self, result_set: pa.Table, query: str): + result_set = self._handle_empty_results(result_set) + if len(result_set) == 0: + return result_set docs = result_set[self.column].to_pylist() response = self._client.rerank( query=query, @@ -101,24 +104,14 @@ class VoyageAIReranker(Reranker): ) return combined_results - def rerank_vector( - self, - query: str, - vector_results: pa.Table, - ): - result_set = self._rerank(vector_results, query) + def rerank_vector(self, query: str, vector_results: pa.Table): + vector_results = self._rerank(vector_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["_distance"]) + vector_results = vector_results.drop_columns(["_distance"]) + return vector_results - return result_set - - def rerank_fts( - self, - query: str, - fts_results: pa.Table, - ): - result_set = self._rerank(fts_results, query) + def rerank_fts(self, query: str, fts_results: pa.Table): + fts_results = self._rerank(fts_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["_score"]) - - return result_set + fts_results = fts_results.drop_columns(["_score"]) + return fts_results diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 4c50c831..d82aae2a 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -457,3 +457,45 @@ def test_voyageai_reranker(tmp_path, use_tantivy): reranker = VoyageAIReranker(model_name="rerank-2") table, schema = get_test_table(tmp_path, use_tantivy) _run_test_reranker(reranker, table, "single player experience", None, schema) + + +def test_empty_result_reranker(): + pytest.importorskip("sentence_transformers") + db = lancedb.connect("memory://") + + # Define schema + schema = pa.schema( + [ + ("id", pa.int64()), + ("text", pa.string()), + ("vector", pa.list_(pa.float32(), 128)), # 128-dimensional vector + ] + ) + + # Create empty table with schema + empty_table = db.create_table("empty_table", schema=schema, mode="overwrite") + empty_table.create_fts_index("text", use_tantivy=False, replace=True) + for reranker in [ + CrossEncoderReranker(), + # ColbertReranker(), + # AnswerdotaiRerankers(), + # OpenaiReranker(), + # JinaReranker(), + # VoyageAIReranker(model_name="rerank-2"), + ]: + results = ( + empty_table.search(list(range(128))) + .limit(3) + .rerank(reranker, "query") + .to_arrow() + ) + # check if empty set contains _relevance_score column + assert "_relevance_score" in results.column_names + assert len(results) == 0 + + results = ( + empty_table.search("query", query_type="fts") + .limit(3) + .rerank(reranker) + .to_arrow() + )