mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 13:52:58 +00:00
feat: support mean reciprocal rank reranker (#2671)
The basic idea of MRR is this - https://www.evidentlyai.com/ranking-metrics/mean-reciprocal-rank-mrr I've implemented a weighted version for allowing user to set weightage between vector and fts. The gist is something like this ### Scenario A: Document at rank 1 in one set, absent from another ``` # Assuming equal weights: weight_vector = 0.5, weight_fts = 0.5 vector_rr = 1.0 # rank 1 → 1/1 = 1.0 fts_rr = 0.0 # absent → 0.0 weighted_mrr = 0.5 × 1.0 + 0.5 × 0.0 = 0.5 ``` ### Scenario B: Document at rank 1 in one set, rank 2 in another ``` # Same weights: weight_vector = 0.5, weight_fts = 0.5 vector_rr = 1.0 # rank 1 → 1/1 = 1.0 fts_rr = 0.5 # rank 2 → 1/2 = 0.5 weighted_mrr = 0.5 × 1.0 + 0.5 × 0.5 = 0.5 + 0.25 = 0.75 ``` And so with `return_score="all"` the result looks something like this (this is from the reranker tests). Because this is a weighted rank based reranker, some results might have the same score ``` text vector _distance _rowid _score _relevance_score 0 I am your father [-0.010703234, 0.069315575, 0.030076642, 0.002... 8.149148e-13 8589934598 10.978719 1.000000 1 the ground beneath my feet [-0.09500901, 0.00092102867, 0.0755851, 0.0372... 1.376896e+00 8589934604 NaN 0.250000 2 I find your lack of faith disturbing [0.07525753, -0.0100010475, 0.09990541, 0.0209... NaN 8589934595 3.483394 0.250000 3 but I don't wanna die [0.033476487, -0.011235877, -0.057625435, -0.0... 1.538222e+00 8589934610 1.130355 0.238095 4 if you strike me down I shall become more powe... [0.00432201, 0.030120496, 5.3317923e-05, 0.033... 1.381086e+00 8589934594 0.715157 0.216667 5 I see a salty message written in the eves [-0.04213107, 0.0016004723, 0.061052393, -0.02... 1.638301e+00 8589934603 1.043785 0.133333 6 but his son was mortal [0.012462767, 0.049041674, -0.057339743, -0.04... 1.421566e+00 8589934620 NaN 0.125000 7 I've got a bad feeling about this [-0.06973199, -0.029960092, 0.02641632, -0.031... NaN 8589934596 1.043785 0.125000 8 now that's a name I haven't heard in a long time [-0.014374257, -0.013588792, -0.07487557, 0.03... 1.597573e+00 8589934593 0.848772 0.118056 9 he was a god [-0.0258895, 0.11925236, -0.029397793, 0.05888... 1.423147e+00 8589934618 NaN 0.100000 10 I wish they would make another one [-0.14737535, -0.015304729, 0.04318139, -0.061... NaN 8589934622 1.043785 0.100000 11 Kratos had a son [-0.057455737, 0.13734367, -0.03537109, -0.000... 1.488075e+00 8589934617 NaN 0.083333 12 I don't wanna live like this [-0.0028891307, 0.015214227, 0.025183653, 0.08... NaN 8589934609 1.043785 0.071429 13 I see a mansard roof through the trees [0.052383978, 0.087759204, 0.014739997, 0.0239... NaN 8589934602 1.043785 0.062500 14 great kid don't get cocky [-0.047043696, 0.054648954, -0.008509666, -0.0... 1.618125e+00 8589934592 NaN 0.055556 ```
This commit is contained in:
@@ -9,6 +9,7 @@ from .linear_combination import LinearCombinationReranker
|
||||
from .openai import OpenaiReranker
|
||||
from .jinaai import JinaReranker
|
||||
from .rrf import RRFReranker
|
||||
from .mrr import MRRReranker
|
||||
from .answerdotai import AnswerdotaiRerankers
|
||||
from .voyageai import VoyageAIReranker
|
||||
|
||||
@@ -23,4 +24,5 @@ __all__ = [
|
||||
"RRFReranker",
|
||||
"AnswerdotaiRerankers",
|
||||
"VoyageAIReranker",
|
||||
"MRRReranker",
|
||||
]
|
||||
|
||||
169
python/python/lancedb/rerankers/mrr.py
Normal file
169
python/python/lancedb/rerankers/mrr.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
|
||||
from typing import Union, List, TYPE_CHECKING
|
||||
import pyarrow as pa
|
||||
import numpy as np
|
||||
|
||||
from collections import defaultdict
|
||||
from .base import Reranker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..table import LanceVectorQueryBuilder
|
||||
|
||||
|
||||
class MRRReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using Mean Reciprocal Rank (MRR) algorithm based
|
||||
on the scores of vector and FTS search.
|
||||
Algorithm reference - https://en.wikipedia.org/wiki/Mean_reciprocal_rank
|
||||
|
||||
MRR calculates the average of reciprocal ranks across different search results.
|
||||
For each document, it computes the reciprocal of its rank in each system,
|
||||
then takes the mean of these reciprocal ranks as the final score.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
weight_vector : float, default 0.5
|
||||
Weight for vector search results (0.0 to 1.0)
|
||||
weight_fts : float, default 0.5
|
||||
Weight for FTS search results (0.0 to 1.0)
|
||||
Note: weight_vector + weight_fts should equal 1.0
|
||||
return_score : str, default "relevance"
|
||||
Options are "relevance" or "all"
|
||||
The type of score to return. If "relevance", will return only the relevance
|
||||
score. If "all", will return all scores from the vector and FTS search along
|
||||
with the relevance score.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_vector: float = 0.5,
|
||||
weight_fts: float = 0.5,
|
||||
return_score="relevance",
|
||||
):
|
||||
if not (0.0 <= weight_vector <= 1.0):
|
||||
raise ValueError("weight_vector must be between 0.0 and 1.0")
|
||||
if not (0.0 <= weight_fts <= 1.0):
|
||||
raise ValueError("weight_fts must be between 0.0 and 1.0")
|
||||
if abs(weight_vector + weight_fts - 1.0) > 1e-6:
|
||||
raise ValueError("weight_vector + weight_fts must equal 1.0")
|
||||
|
||||
super().__init__(return_score)
|
||||
self.weight_vector = weight_vector
|
||||
self.weight_fts = weight_fts
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str, # noqa: F821
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
vector_ids = vector_results["_rowid"].to_pylist() if vector_results else []
|
||||
fts_ids = fts_results["_rowid"].to_pylist() if fts_results else []
|
||||
|
||||
# Maps result_id to list of (type, reciprocal_rank)
|
||||
mrr_score_map = defaultdict(list)
|
||||
|
||||
if vector_ids:
|
||||
for rank, result_id in enumerate(vector_ids, 1):
|
||||
reciprocal_rank = 1.0 / rank
|
||||
mrr_score_map[result_id].append(("vector", reciprocal_rank))
|
||||
|
||||
if fts_ids:
|
||||
for rank, result_id in enumerate(fts_ids, 1):
|
||||
reciprocal_rank = 1.0 / rank
|
||||
mrr_score_map[result_id].append(("fts", reciprocal_rank))
|
||||
|
||||
final_mrr_scores = {}
|
||||
for result_id, scores in mrr_score_map.items():
|
||||
vector_rr = 0.0
|
||||
fts_rr = 0.0
|
||||
|
||||
for score_type, reciprocal_rank in scores:
|
||||
if score_type == "vector":
|
||||
vector_rr = reciprocal_rank
|
||||
elif score_type == "fts":
|
||||
fts_rr = reciprocal_rank
|
||||
|
||||
# If a document doesn't appear, its reciprocal rank is 0
|
||||
weighted_mrr = self.weight_vector * vector_rr + self.weight_fts * fts_rr
|
||||
final_mrr_scores[result_id] = weighted_mrr
|
||||
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_row_ids = combined_results["_rowid"].to_pylist()
|
||||
relevance_scores = [final_mrr_scores[row_id] for row_id in combined_row_ids]
|
||||
combined_results = combined_results.append_column(
|
||||
"_relevance_score", pa.array(relevance_scores, type=pa.float32())
|
||||
)
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_multivector(
|
||||
self,
|
||||
vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]],
|
||||
query: str = None,
|
||||
deduplicate: bool = True, # noqa: F821
|
||||
):
|
||||
"""
|
||||
Reranks the results from multiple vector searches using MRR algorithm.
|
||||
Each vector search result is treated as a separate ranking system,
|
||||
and MRR calculates the mean of reciprocal ranks across all systems.
|
||||
This cannot reuse rerank_hybrid because MRR semantics require treating
|
||||
each vector result as a separate ranking system.
|
||||
"""
|
||||
if not all(isinstance(v, type(vector_results[0])) for v in vector_results):
|
||||
raise ValueError(
|
||||
"All elements in vector_results should be of the same type"
|
||||
)
|
||||
|
||||
# avoid circular import
|
||||
if type(vector_results[0]).__name__ == "LanceVectorQueryBuilder":
|
||||
vector_results = [result.to_arrow() for result in vector_results]
|
||||
elif not isinstance(vector_results[0], pa.Table):
|
||||
raise ValueError(
|
||||
"vector_results should be a list of pa.Table or LanceVectorQueryBuilder"
|
||||
)
|
||||
|
||||
if not all("_rowid" in result.column_names for result in vector_results):
|
||||
raise ValueError(
|
||||
"'_rowid' is required for deduplication. \
|
||||
add _rowid to search results like this: \
|
||||
`search().with_row_id(True)`"
|
||||
)
|
||||
|
||||
mrr_score_map = defaultdict(list)
|
||||
|
||||
for result_table in vector_results:
|
||||
result_ids = result_table["_rowid"].to_pylist()
|
||||
for rank, result_id in enumerate(result_ids, 1):
|
||||
reciprocal_rank = 1.0 / rank
|
||||
mrr_score_map[result_id].append(reciprocal_rank)
|
||||
|
||||
final_mrr_scores = {}
|
||||
for result_id, reciprocal_ranks in mrr_score_map.items():
|
||||
mean_rr = np.mean(reciprocal_ranks)
|
||||
final_mrr_scores[result_id] = mean_rr
|
||||
|
||||
combined = pa.concat_tables(vector_results, **self._concat_tables_args)
|
||||
combined = self._deduplicate(combined)
|
||||
|
||||
combined_row_ids = combined["_rowid"].to_pylist()
|
||||
|
||||
relevance_scores = [final_mrr_scores[row_id] for row_id in combined_row_ids]
|
||||
combined = combined.append_column(
|
||||
"_relevance_score", pa.array(relevance_scores, type=pa.float32())
|
||||
)
|
||||
combined = combined.sort_by([("_relevance_score", "descending")])
|
||||
|
||||
if self.score == "relevance":
|
||||
combined = self._keep_relevance_score(combined)
|
||||
|
||||
return combined
|
||||
@@ -22,6 +22,7 @@ from lancedb.rerankers import (
|
||||
JinaReranker,
|
||||
AnswerdotaiRerankers,
|
||||
VoyageAIReranker,
|
||||
MRRReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -46,6 +47,7 @@ def get_test_table(tmp_path, use_tantivy):
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
mode="overwrite",
|
||||
)
|
||||
|
||||
# Need to test with a bunch of phrases to make sure sorting is consistent
|
||||
@@ -96,7 +98,7 @@ def get_test_table(tmp_path, use_tantivy):
|
||||
)
|
||||
|
||||
# Create a fts index
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy)
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy, replace=True)
|
||||
|
||||
return table, MyTable
|
||||
|
||||
@@ -320,6 +322,34 @@ def test_rrf_reranker(tmp_path, use_tantivy):
|
||||
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_mrr_reranker(tmp_path, use_tantivy):
|
||||
reranker = MRRReranker()
|
||||
_run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)
|
||||
|
||||
# Test multi-vector part
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
query = "single player experience"
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
rs2 = (
|
||||
table.search(query, vector_column_name="meta_vector")
|
||||
.limit(10)
|
||||
.with_row_id(True)
|
||||
)
|
||||
result = reranker.rerank_multivector([rs1, rs2])
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert len(result) <= 20
|
||||
|
||||
if len(result) > 1:
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _relevance_score should be descending."
|
||||
)
|
||||
|
||||
# Test with duplicate results
|
||||
result_deduped = reranker.rerank_multivector([rs1, rs2, rs1])
|
||||
assert len(result_deduped) == len(result)
|
||||
|
||||
|
||||
def test_rrf_reranker_distance():
|
||||
data = pa.table(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user