diff --git a/python/python/lancedb/rerankers/mrr.py b/python/python/lancedb/rerankers/mrr.py index af5ce778d..81a03c9ca 100644 --- a/python/python/lancedb/rerankers/mrr.py +++ b/python/python/lancedb/rerankers/mrr.py @@ -156,9 +156,16 @@ class MRRReranker(Reranker): reciprocal_rank = 1.0 / rank mrr_score_map[result_id].append(reciprocal_rank) + # MRR averages the reciprocal rank across *all* ranking systems, treating + # a system in which a document does not appear as a reciprocal rank of 0. + # We therefore divide by the total number of systems, not by the number of + # systems the document happens to appear in -- otherwise a document found + # by a single ranking would outrank one ranked highly by every system, + # defeating the purpose of fusing the rankings. + num_systems = len(vector_results) final_mrr_scores = {} for result_id, reciprocal_ranks in mrr_score_map.items(): - mean_rr = np.mean(reciprocal_ranks) + mean_rr = float(np.sum(reciprocal_ranks)) / num_systems final_mrr_scores[result_id] = mean_rr combined = pa.concat_tables(vector_results, **self._concat_tables_args) diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 5ab24be5d..3ea119670 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -350,6 +350,38 @@ def test_mrr_reranker_empty_input(): reranker.rerank_multivector([]) +def test_mrr_multivector_rewards_consensus(): + # Reciprocal ranks must be averaged across *all* ranking systems, treating a + # missing system as 0. A document ranked first by every system must outrank a + # document ranked first by only one of them. + reranker = MRRReranker() + + def ranking(row_ids): + return pa.table({"_rowid": pa.array(row_ids, type=pa.int64())}) + + # Doc 1 is rank 1 in only the first system; doc 2 is rank 1 in two systems + # and rank 2 in the third (strong cross-system consensus). + rs1 = ranking([1, 2, 3]) + rs2 = ranking([2, 3, 4]) + rs3 = ranking([2, 5, 6]) + + result = reranker.rerank_multivector([rs1, rs2, rs3]) + scores = { + row_id: score + for row_id, score in zip( + result["_rowid"].to_pylist(), + result["_relevance_score"].to_pylist(), + ) + } + + # sum of reciprocal ranks / number of systems + assert scores[1] == pytest.approx(1.0 / 3) + assert scores[2] == pytest.approx((0.5 + 1.0 + 1.0) / 3) + assert scores[2] > scores[1] + # The consensus document ranks first overall. + assert result["_rowid"].to_pylist()[0] == 2 + + def test_rrf_reranker_distance(): data = pa.table( {