diff --git a/python/python/lancedb/rerankers/__init__.py b/python/python/lancedb/rerankers/__init__.py index c7e34d57..0b767a67 100644 --- a/python/python/lancedb/rerankers/__init__.py +++ b/python/python/lancedb/rerankers/__init__.py @@ -5,6 +5,7 @@ from .cross_encoder import CrossEncoderReranker from .linear_combination import LinearCombinationReranker from .openai import OpenaiReranker from .jinaai import JinaReranker +from .rrf import RRFReranker __all__ = [ "Reranker", @@ -14,4 +15,5 @@ __all__ = [ "OpenaiReranker", "ColbertReranker", "JinaReranker", + "RRFReranker", ] diff --git a/python/python/lancedb/rerankers/rrf.py b/python/python/lancedb/rerankers/rrf.py new file mode 100644 index 00000000..a4d95a39 --- /dev/null +++ b/python/python/lancedb/rerankers/rrf.py @@ -0,0 +1,60 @@ +import pyarrow as pa + +from collections import defaultdict +from .base import Reranker + + +class RRFReranker(Reranker): + """ + Reranks the results using Reciprocal Rank Fusion(RRF) algorithm based + on the scores of vector and FTS search. + Parameters + ---------- + K : int, default 60 + A constant used in the RRF formula (default is 60). Experiments + indicate that k = 60 was near-optimal, but that the choice is + not critical. See paper: + https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf + return_score : str, default "relevance" + opntions are "relevance" or "all" + The type of score to return. If "relevance", will return only the relevance + score. If "all", will return all scores from the vector and FTS search along + with the relevance score. + """ + + def __init__(self, K: int = 60, return_score="relevance"): + if K <= 0: + raise ValueError("K must be greater than 0") + super().__init__(return_score) + self.K = K + + def rerank_hybrid( + self, + query: str, # noqa: F821 + vector_results: pa.Table, + fts_results: pa.Table, + ): + vector_ids = vector_results["_rowid"].to_pylist() if vector_results else [] + fts_ids = fts_results["_rowid"].to_pylist() if fts_results else [] + rrf_score_map = defaultdict(float) + + # Calculate RRF score of each result + for ids in [vector_ids, fts_ids]: + for i, result_id in enumerate(ids, 1): + rrf_score_map[result_id] += 1 / (i + self.K) + + # Sort the results based on RRF score + combined_results = self.merge_results(vector_results, fts_results) + combined_row_ids = combined_results["_rowid"].to_pylist() + relevance_scores = [rrf_score_map[row_id] for row_id in combined_row_ids] + combined_results = combined_results.append_column( + "_relevance_score", pa.array(relevance_scores, type=pa.float32()) + ) + combined_results = combined_results.sort_by( + [("_relevance_score", "descending")] + ) + + if self.score == "relevance": + combined_results = combined_results.drop_columns(["score", "_distance"]) + + return combined_results diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 1e303905..b7ffc10c 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -7,6 +7,8 @@ from lancedb.conftest import MockTextEmbeddingFunction # noqa from lancedb.embeddings import EmbeddingFunctionRegistry from lancedb.pydantic import LanceModel, Vector from lancedb.rerankers import ( + LinearCombinationReranker, + RRFReranker, CohereReranker, ColbertReranker, CrossEncoderReranker, @@ -140,7 +142,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err -def test_linear_combination(tmp_path): +def _run_test_hybrid_reranker(reranker, tmp_path): table, schema = get_test_table(tmp_path) # The default reranker result1 = ( @@ -177,6 +179,16 @@ def test_linear_combination(tmp_path): ) +def test_linear_combination(tmp_path): + reranker = LinearCombinationReranker() + _run_test_hybrid_reranker(reranker, tmp_path) + + +def test_rrf_reranker(tmp_path): + reranker = RRFReranker() + _run_test_hybrid_reranker(reranker, tmp_path) + + @pytest.mark.skipif( os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" )