mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 13:22:58 +00:00
feat: add reciprocal rank fusion reranker (#1456)
Implements https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf Refactors the hybrid search only rerrankers test to avoid repetition.
This commit is contained in:
@@ -5,6 +5,7 @@ from .cross_encoder import CrossEncoderReranker
|
||||
from .linear_combination import LinearCombinationReranker
|
||||
from .openai import OpenaiReranker
|
||||
from .jinaai import JinaReranker
|
||||
from .rrf import RRFReranker
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
@@ -14,4 +15,5 @@ __all__ = [
|
||||
"OpenaiReranker",
|
||||
"ColbertReranker",
|
||||
"JinaReranker",
|
||||
"RRFReranker",
|
||||
]
|
||||
|
||||
60
python/python/lancedb/rerankers/rrf.py
Normal file
60
python/python/lancedb/rerankers/rrf.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import pyarrow as pa
|
||||
|
||||
from collections import defaultdict
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class RRFReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using Reciprocal Rank Fusion(RRF) algorithm based
|
||||
on the scores of vector and FTS search.
|
||||
Parameters
|
||||
----------
|
||||
K : int, default 60
|
||||
A constant used in the RRF formula (default is 60). Experiments
|
||||
indicate that k = 60 was near-optimal, but that the choice is
|
||||
not critical. See paper:
|
||||
https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
|
||||
return_score : str, default "relevance"
|
||||
opntions are "relevance" or "all"
|
||||
The type of score to return. If "relevance", will return only the relevance
|
||||
score. If "all", will return all scores from the vector and FTS search along
|
||||
with the relevance score.
|
||||
"""
|
||||
|
||||
def __init__(self, K: int = 60, return_score="relevance"):
|
||||
if K <= 0:
|
||||
raise ValueError("K must be greater than 0")
|
||||
super().__init__(return_score)
|
||||
self.K = K
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str, # noqa: F821
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
vector_ids = vector_results["_rowid"].to_pylist() if vector_results else []
|
||||
fts_ids = fts_results["_rowid"].to_pylist() if fts_results else []
|
||||
rrf_score_map = defaultdict(float)
|
||||
|
||||
# Calculate RRF score of each result
|
||||
for ids in [vector_ids, fts_ids]:
|
||||
for i, result_id in enumerate(ids, 1):
|
||||
rrf_score_map[result_id] += 1 / (i + self.K)
|
||||
|
||||
# Sort the results based on RRF score
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_row_ids = combined_results["_rowid"].to_pylist()
|
||||
relevance_scores = [rrf_score_map[row_id] for row_id in combined_row_ids]
|
||||
combined_results = combined_results.append_column(
|
||||
"_relevance_score", pa.array(relevance_scores, type=pa.float32())
|
||||
)
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
|
||||
return combined_results
|
||||
@@ -7,6 +7,8 @@ from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import (
|
||||
LinearCombinationReranker,
|
||||
RRFReranker,
|
||||
CohereReranker,
|
||||
ColbertReranker,
|
||||
CrossEncoderReranker,
|
||||
@@ -140,7 +142,7 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
def test_linear_combination(tmp_path):
|
||||
def _run_test_hybrid_reranker(reranker, tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
result1 = (
|
||||
@@ -177,6 +179,16 @@ def test_linear_combination(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
def test_linear_combination(tmp_path):
|
||||
reranker = LinearCombinationReranker()
|
||||
_run_test_hybrid_reranker(reranker, tmp_path)
|
||||
|
||||
|
||||
def test_rrf_reranker(tmp_path):
|
||||
reranker = RRFReranker()
|
||||
_run_test_hybrid_reranker(reranker, tmp_path)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user