feat: add reciprocal rank fusion reranker (#1456)

Implements https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

Refactors the hybrid-search-only rerankers test to avoid repetition.
This commit is contained in:
Ayush Chaurasia
2024-07-23 21:37:17 +05:30
committed by GitHub
parent 4ee229490c
commit 0255221086
3 changed files with 75 additions and 1 deletions

View File

@@ -5,6 +5,7 @@ from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
from .jinaai import JinaReranker
from .rrf import RRFReranker
__all__ = [
"Reranker",
@@ -14,4 +15,5 @@ __all__ = [
"OpenaiReranker",
"ColbertReranker",
"JinaReranker",
"RRFReranker",
]

View File

@@ -0,0 +1,60 @@
import pyarrow as pa
from collections import defaultdict
from .base import Reranker
class RRFReranker(Reranker):
    """
    Reranks hybrid search results using the Reciprocal Rank Fusion (RRF)
    algorithm.

    RRF is rank-based: each row's relevance is the sum, over the vector and
    FTS result lists, of ``1 / (K + rank)`` where ``rank`` is the row's
    1-based position in that list. The raw vector distance and FTS score
    values are not used in the computation.

    Parameters
    ----------
    K : int, default 60
        A constant used in the RRF formula (default is 60). Experiments
        indicate that k = 60 was near-optimal, but that the choice is
        not critical. See paper:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
    return_score : str, default "relevance"
        Options are "relevance" or "all".
        The type of score to return. If "relevance", will return only the
        relevance score. If "all", will return all scores from the vector
        and FTS search along with the relevance score.

    Raises
    ------
    ValueError
        If ``K`` is not greater than 0.
    """

    def __init__(self, K: int = 60, return_score="relevance"):
        # Validate before touching the base class so a bad K fails fast.
        if K <= 0:
            raise ValueError("K must be greater than 0")
        super().__init__(return_score)
        self.K = K

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Fuse vector and FTS results into one table ordered by RRF score.

        Parameters
        ----------
        query : str
            The search query. Not used by RRF; accepted for interface
            compatibility with other rerankers.
        vector_results : pa.Table
            Vector search results; must contain a ``_rowid`` column.
        fts_results : pa.Table
            Full-text search results; must contain a ``_rowid`` column.

        Returns
        -------
        pa.Table
            The merged results with an added float32 ``_relevance_score``
            column, sorted by that column in descending order.
        """
        # A falsy table (None or zero rows) contributes no ranks.
        vector_ids = vector_results["_rowid"].to_pylist() if vector_results else []
        fts_ids = fts_results["_rowid"].to_pylist() if fts_results else []

        # Accumulate 1 / (K + rank) per row id; ranks are 1-based. A row id
        # missing from one list simply gets no contribution from it.
        rrf_score_map = defaultdict(float)
        for ids in (vector_ids, fts_ids):
            for rank, result_id in enumerate(ids, 1):
                rrf_score_map[result_id] += 1 / (rank + self.K)

        # merge_results is inherited from Reranker (not visible here);
        # presumably it combines both tables by _rowid -- TODO confirm.
        combined_results = self.merge_results(vector_results, fts_results)
        combined_row_ids = combined_results["_rowid"].to_pylist()
        relevance_scores = [rrf_score_map[row_id] for row_id in combined_row_ids]
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(relevance_scores, type=pa.float32())
        )
        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        if self.score == "relevance":
            # Drop only the raw score columns that actually exist:
            # drop_columns raises KeyError on a missing column name.
            raw_score_columns = [
                name
                for name in ("score", "_distance")
                if name in combined_results.column_names
            ]
            if raw_score_columns:
                combined_results = combined_results.drop_columns(raw_score_columns)

        return combined_results

View File

@@ -7,6 +7,8 @@ from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import (
LinearCombinationReranker,
RRFReranker,
CohereReranker,
ColbertReranker,
CrossEncoderReranker,
@@ -140,7 +142,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_linear_combination(tmp_path):
def _run_test_hybrid_reranker(reranker, tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
@@ -177,6 +179,16 @@ def test_linear_combination(tmp_path):
)
def test_linear_combination(tmp_path):
    """Run the shared hybrid-reranker checks with a default LinearCombinationReranker."""
    _run_test_hybrid_reranker(LinearCombinationReranker(), tmp_path)
def test_rrf_reranker(tmp_path):
    """Run the shared hybrid-reranker checks with a default RRFReranker."""
    _run_test_hybrid_reranker(RRFReranker(), tmp_path)
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)