lancedb/python/python/tests/test_rerankers.py
Ayush Chaurasia e921c90c1b feat: support mean reciprocal rank reranker (#2671)
The basic idea of MRR (mean reciprocal rank) is described here:
https://www.evidentlyai.com/ranking-metrics/mean-reciprocal-rank-mrr
I've implemented a weighted version that lets the user set the weighting
between the vector and FTS results.

The gist is something like this:

### Scenario A: Document at rank 1 in one set, absent from the other

```
# Assuming equal weights: weight_vector = 0.5, weight_fts = 0.5
vector_rr = 1.0  # rank 1 → 1/1 = 1.0
fts_rr = 0.0     # absent → 0.0

weighted_mrr = 0.5 × 1.0 + 0.5 × 0.0 = 0.5
```
### Scenario B: Document at rank 1 in one set, rank 2 in the other
```
# Same weights: weight_vector = 0.5, weight_fts = 0.5
vector_rr = 1.0  # rank 1 → 1/1 = 1.0
fts_rr = 0.5     # rank 2 → 1/2 = 0.5

weighted_mrr = 0.5 × 1.0 + 0.5 × 0.5 = 0.5 + 0.25 = 0.75
```
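
Expressed as code, the combination is roughly the following. This is only a minimal sketch of the idea above, not the actual implementation; the function and parameter names (`weighted_reciprocal_rank`, `weight_vector`, `weight_fts`) are made up for illustration.

```
def weighted_reciprocal_rank(vector_rank, fts_rank, weight_vector=0.5, weight_fts=0.5):
    # A rank of None means the document is absent from that result set,
    # which contributes 0 to the combined score.
    vector_rr = 1.0 / vector_rank if vector_rank is not None else 0.0
    fts_rr = 1.0 / fts_rank if fts_rank is not None else 0.0
    return weight_vector * vector_rr + weight_fts * fts_rr


# Scenario A: rank 1 in the vector results, absent from the FTS results
assert weighted_reciprocal_rank(1, None) == 0.5

# Scenario B: rank 1 in the vector results, rank 2 in the FTS results
assert weighted_reciprocal_rank(1, 2) == 0.75
```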

And so with `return_score="all"` the result looks something like this
(this is from the reranker tests). Because this is a weighted, rank-based
reranker, some results might have the same score:
```
                                                 text                                             vector     _distance      _rowid     _score  _relevance_score
0                                    I am your father  [-0.010703234, 0.069315575, 0.030076642, 0.002...  8.149148e-13  8589934598  10.978719          1.000000
1                          the ground beneath my feet  [-0.09500901, 0.00092102867, 0.0755851, 0.0372...  1.376896e+00  8589934604        NaN          0.250000
2                I find your lack of faith disturbing  [0.07525753, -0.0100010475, 0.09990541, 0.0209...           NaN  8589934595   3.483394          0.250000
3                               but I don't wanna die  [0.033476487, -0.011235877, -0.057625435, -0.0...  1.538222e+00  8589934610   1.130355          0.238095
4   if you strike me down I shall become more powe...  [0.00432201, 0.030120496, 5.3317923e-05, 0.033...  1.381086e+00  8589934594   0.715157          0.216667
5           I see a salty message written in the eves  [-0.04213107, 0.0016004723, 0.061052393, -0.02...  1.638301e+00  8589934603   1.043785          0.133333
6                              but his son was mortal  [0.012462767, 0.049041674, -0.057339743, -0.04...  1.421566e+00  8589934620        NaN          0.125000
7                   I've got a bad feeling about this  [-0.06973199, -0.029960092, 0.02641632, -0.031...           NaN  8589934596   1.043785          0.125000
8    now that's a name I haven't heard in a long time  [-0.014374257, -0.013588792, -0.07487557, 0.03...  1.597573e+00  8589934593   0.848772          0.118056
9                                        he was a god  [-0.0258895, 0.11925236, -0.029397793, 0.05888...  1.423147e+00  8589934618        NaN          0.100000
10                 I wish they would make another one  [-0.14737535, -0.015304729, 0.04318139, -0.061...           NaN  8589934622   1.043785          0.100000
11                                   Kratos had a son  [-0.057455737, 0.13734367, -0.03537109, -0.000...  1.488075e+00  8589934617        NaN          0.083333
12                       I don't wanna live like this  [-0.0028891307, 0.015214227, 0.025183653, 0.08...           NaN  8589934609   1.043785          0.071429
13             I see a mansard roof through the trees  [0.052383978, 0.087759204, 0.014739997, 0.0239...           NaN  8589934602   1.043785          0.062500
14                          great kid don't get cocky  [-0.047043696, 0.054648954, -0.008509666, -0.0...  1.618125e+00  8589934592        NaN          0.055556
```
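
For reference, a table like the one above can be produced with a hybrid query along these lines. This is a sketch that mirrors the hybrid-search pattern used in the tests below; it assumes a `table` set up like the one in `get_test_table()`, and any `MRRReranker` constructor arguments beyond `return_score="all"` are not shown.

```
from lancedb.rerankers import MRRReranker

# Sketch only: assumes `table` is a LanceDB table with a "vector" column and
# an FTS index on its "text" column, as built in get_test_table() below.
reranker = MRRReranker(return_score="all")
result = (
    table.search("I am your father", query_type="hybrid", vector_column_name="vector")
    .rerank(reranker=reranker)
    .to_arrow()
    .to_pandas()
)
print(result)
```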

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import os
import random
import lancedb
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pytest
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import (
    LinearCombinationReranker,
    RRFReranker,
    CohereReranker,
    ColbertReranker,
    CrossEncoderReranker,
    OpenaiReranker,
    JinaReranker,
    AnswerdotaiRerankers,
    VoyageAIReranker,
    MRRReranker,
)
from lancedb.table import LanceTable

# Tests rely on FTS index
pytest.importorskip("lancedb.fts")


def get_test_table(tmp_path, use_tantivy):
    db = lancedb.connect(tmp_path)
    # Create a LanceDB table schema with a vector and a text column
    emb = EmbeddingFunctionRegistry.get_instance().get("test").create()
    meta_emb = EmbeddingFunctionRegistry.get_instance().get("test").create()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
        vector: Vector(emb.ndims()) = emb.VectorField()
        meta: str = meta_emb.SourceField()
        meta_vector: Vector(meta_emb.ndims()) = meta_emb.VectorField()

    # Initialize the table using the schema
    table = LanceTable.create(
        db,
        "my_table",
        schema=MyTable,
        mode="overwrite",
    )

    # Need to test with a bunch of phrases to make sure sorting is consistent
    phrases = [
        "great kid don't get cocky",
        "now that's a name I haven't heard in a long time",
        "if you strike me down I shall become more powerful than you imagine",
        "I find your lack of faith disturbing",
        "I've got a bad feeling about this",
        "never tell me the odds",
        "I am your father",
        "somebody has to save our skins",
        "New strategy R2 let the wookiee win",
        "Arrrrggghhhhhhh",
        "I see a mansard roof through the trees",
        "I see a salty message written in the eves",
        "the ground beneath my feet",
        "the hot garbage and concrete",
        "and now the tops of buildings",
        "everybody with a worried mind could never forgive the sight",
        "of wicked snakes inside a place you thought was dignified",
        "I don't wanna live like this",
        "but I don't wanna die",
        "The templars want control",
        "the brotherhood of assassins want freedom",
        "if only they could both see the world as it really is",
        "there would be peace",
        "but the war goes on",
        "altair's legacy was a warning",
        "Kratos had a son",
        "he was a god",
        "the god of war",
        "but his son was mortal",
        "there hasn't been a good battlefield game since 2142",
        "I wish they would make another one",
        "campains are not as good as they used to be",
        "Multiplayer and open world games have destroyed the single player experience",
        "Maybe the future is console games",
        "I don't know",
    ]

    # Add the phrases and vectors to the table
    table.add(
        [
            {"text": p, "meta": phrases[random.randint(0, len(phrases) - 1)]}
            for p in phrases
        ]
    )

    # Create a fts index
    table.create_fts_index("text", use_tantivy=use_tantivy, replace=True)

    return table, MyTable


def _run_test_reranker(reranker, table, query, query_vector, schema):
    # Hybrid search setting
    result1 = (
        table.search(query, query_type="hybrid", vector_column_name="vector")
        .rerank(normalize="score", reranker=reranker)
        .to_pydantic(schema)
    )
    result2 = (
        table.search(query, query_type="hybrid", vector_column_name="vector")
        .rerank(reranker=reranker)
        .to_pydantic(schema)
    )
    assert result1 == result2

    query_vector = table.to_pandas()["vector"][0]
    result = (
        table.search(query_type="hybrid", vector_column_name="vector")
        .vector(query_vector)
        .text(query)
        .limit(30)
        .rerank(reranker=reranker)
        .to_arrow()
    )
    assert len(result) == 30

    ascending_relevance_err = (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        ascending_relevance_err
    )

    # Vector search setting
    result = (
        table.search(query, vector_column_name="vector")
        .rerank(reranker=reranker)
        .limit(30)
        .to_arrow()
    )
    assert len(result) == 30
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        ascending_relevance_err
    )
    result_explicit = (
        table.search(query_vector, vector_column_name="vector")
        .rerank(reranker=reranker, query_string=query)
        .limit(30)
        .to_arrow()
    )
    assert len(result_explicit) == 30
    with pytest.raises(
        ValueError
    ):  # This raises an error because a vector query is provided without a reranking query
        table.search(query_vector, vector_column_name="vector").rerank(
            reranker=reranker
        ).limit(30).to_arrow()
    # FTS search setting
    result = (
        table.search(query, query_type="fts", vector_column_name="vector")
        .rerank(reranker=reranker)
        .limit(30)
        .to_arrow()
    )
    assert len(result) > 0
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        ascending_relevance_err
    )

    # empty FTS results
    query = "abcxyz" * 100
    result = (
        table.search(query_type="hybrid", vector_column_name="vector")
        .vector(query_vector)
        .text(query)
        .limit(30)
        .rerank(reranker=reranker)
        .to_arrow()
    )
    # should return _relevance_score column
    assert "_relevance_score" in result.column_names
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        ascending_relevance_err
    )

    # Multi-vector search setting
    rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
    rs2 = (
        table.search(query, vector_column_name="meta_vector")
        .limit(10)
        .with_row_id(True)
    )
    result = reranker.rerank_multivector([rs1, rs2], query)
    assert len(result) == 20
    result_deduped = reranker.rerank_multivector(
        [rs1, rs2, rs1], query, deduplicate=True
    )
    assert len(result_deduped) <= 20
    result_arrow = reranker.rerank_multivector([rs1.to_arrow(), rs2.to_arrow()], query)
    assert len(result) == 20 and result == result_arrow


def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
    table, schema = get_test_table(tmp_path, use_tantivy)
    # The default reranker
    result1 = (
        table.search(
            "Our father who art in heaven",
            query_type="hybrid",
            vector_column_name="vector",
        )
        .rerank(normalize="score")
        .to_pydantic(schema)
    )
    result2 = (  # noqa
        table.search(
            "Our father who art in heaven.",
            query_type="hybrid",
            vector_column_name="vector",
        )
        .rerank(normalize="rank")
        .to_pydantic(schema)
    )
    result3 = table.search(
        "Our father who art in heaven..",
        query_type="hybrid",
        vector_column_name="vector",
    ).to_pydantic(schema)
    assert result1 == result3  # 2 & 3 should be the same as they use score as score

    query = "Our father who art in heaven"
    query_vector = table.to_pandas()["vector"][0]
    result = (
        table.search(query_type="hybrid", vector_column_name="vector")
        .vector(query_vector)
        .text(query)
        .limit(30)
        .rerank(reranker, normalize="score")
        .to_arrow()
    )
    assert len(result) == 30

    # Fail if both query and (vector or text) are provided
    with pytest.raises(ValueError):
        table.search(query, query_type="hybrid", vector_column_name="vector").vector(
            query_vector
        ).to_arrow()
    with pytest.raises(ValueError):
        table.search(query, query_type="hybrid", vector_column_name="vector").text(
            query
        ).to_arrow()

    ascending_relevance_err = (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        ascending_relevance_err
    )

    # Test with empty FTS results
    query = "abcxyz" * 100
    result = (
        table.search(query_type="hybrid", vector_column_name="vector")
        .vector(query_vector)
        .text(query)
        .limit(30)
        .rerank(reranker=reranker)
        .to_arrow()
    )
    # should return _relevance_score column
    assert "_relevance_score" in result.column_names
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        ascending_relevance_err
    )


@pytest.mark.parametrize("use_tantivy", [True, False])
def test_linear_combination(tmp_path, use_tantivy):
    reranker = LinearCombinationReranker()
    vector_results = pa.Table.from_pydict(
        {
            "_rowid": [0, 1, 2, 3, 4],
            "_distance": [0.1, 0.2, 0.3, 0.4, 0.5],
            "_text": ["a", "b", "c", "d", "e"],
        }
    )
    fts_results = pa.Table.from_pydict(
        {
            "_rowid": [1, 2, 3, 4, 5],
            "_score": [0.1, 0.2, 0.3, 0.4, 0.5],
            "_text": ["b", "c", "d", "e", "f"],
        }
    )
    combined_results = reranker.merge_results(vector_results, fts_results, 1.0)
    assert len(combined_results) == 6
    assert "_rowid" in combined_results.column_names
    assert "_text" in combined_results.column_names
    assert "_distance" not in combined_results.column_names
    assert "_score" not in combined_results.column_names
    assert "_relevance_score" in combined_results.column_names

    _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)


@pytest.mark.parametrize("use_tantivy", [True, False])
def test_rrf_reranker(tmp_path, use_tantivy):
    reranker = RRFReranker()
    _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)


@pytest.mark.parametrize("use_tantivy", [True, False])
def test_mrr_reranker(tmp_path, use_tantivy):
    reranker = MRRReranker()
    _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy)

    # Test multi-vector part
    table, schema = get_test_table(tmp_path, use_tantivy)
    query = "single player experience"
    rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
    rs2 = (
        table.search(query, vector_column_name="meta_vector")
        .limit(10)
        .with_row_id(True)
    )
    result = reranker.rerank_multivector([rs1, rs2])
    assert "_relevance_score" in result.column_names
    assert len(result) <= 20
    if len(result) > 1:
        assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
            "The _relevance_score should be descending."
        )

    # Test with duplicate results
    result_deduped = reranker.rerank_multivector([rs1, rs2, rs1])
    assert len(result_deduped) == len(result)


def test_rrf_reranker_distance():
    data = pa.table(
        {
            "vector": pa.FixedSizeListArray.from_arrays(
                pc.random(32 * 1024).cast(pa.float32()), 32
            ),
            "text": pa.array(["hello"] * 1024),
        }
    )
    db = lancedb.connect("memory://")
    table = db.create_table("test", data)
    table.create_index(num_partitions=1, num_sub_vectors=2)
    table.create_fts_index("text", use_tantivy=False)

    reranker = RRFReranker(return_score="all")
    hybrid_results = (
        table.search(query_type="hybrid")
        .vector([0.0] * 32)
        .text("hello")
        .with_row_id(True)
        .rerank(reranker)
        .to_list()
    )
    hybrid_distances = {row["_rowid"]: row["_distance"] for row in hybrid_results}
    hybrid_scores = {row["_rowid"]: row["_score"] for row in hybrid_results}

    vector_results = table.search([0.0] * 32).with_row_id(True).to_list()
    vector_distances = {row["_rowid"]: row["_distance"] for row in vector_results}
    fts_results = table.search("hello", query_type="fts").with_row_id(True).to_list()
    fts_scores = {row["_rowid"]: row["_score"] for row in fts_results}

    found_match = False
    for rowid, distance in hybrid_distances.items():
        if rowid in vector_distances:
            found_match = True
            assert distance == vector_distances[rowid], "Distance mismatch"
    assert found_match, "No results matched between hybrid and vector search"

    found_match = False
    for rowid, score in hybrid_scores.items():
        if rowid in fts_scores and fts_scores[rowid] is not None:
            found_match = True
            assert score == fts_scores[rowid], "Score mismatch"
    assert found_match, "No results matched between hybrid and fts search"

    # Test for empty fts results
    fts_results = (
        table.search("abcxyz" * 100, query_type="fts").with_row_id(True).to_list()
    )
    hybrid_results = (
        table.search(query_type="hybrid")
        .vector([0.0] * 32)
        .text("abcxyz" * 100)
        .with_row_id(True)
        .rerank(reranker)
        .to_list()
    )
    assert len(fts_results) == 0
    # confirm if _rowid, _score, _distance & _relevance_score are present in hybrid
    assert len(hybrid_results) > 0
    assert "_rowid" in hybrid_results[0]
    assert "_score" in hybrid_results[0]
    assert "_distance" in hybrid_results[0]
    assert "_relevance_score" in hybrid_results[0]


@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_cohere_reranker(tmp_path, use_tantivy):
    pytest.importorskip("cohere")
    reranker = CohereReranker()
    table, schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(reranker, table, "single player experience", None, schema)


@pytest.mark.parametrize("use_tantivy", [True, False])
def test_cross_encoder_reranker(tmp_path, use_tantivy):
    pytest.importorskip("sentence_transformers")
    reranker = CrossEncoderReranker()
    table, schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(reranker, table, "single player experience", None, schema)


@pytest.mark.parametrize("use_tantivy", [True, False])
def test_colbert_reranker(tmp_path, use_tantivy):
    pytest.importorskip("rerankers")
    reranker = ColbertReranker()
    table, schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(reranker, table, "single player experience", None, schema)


@pytest.mark.parametrize("use_tantivy", [True, False])
def test_answerdotai_reranker(tmp_path, use_tantivy):
    pytest.importorskip("rerankers")
    reranker = AnswerdotaiRerankers()
    table, schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(reranker, table, "single player experience", None, schema)


@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None
    or os.environ.get("OPENAI_BASE_URL") is not None,
    reason="OPENAI_API_KEY not set",
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_openai_reranker(tmp_path, use_tantivy):
    pytest.importorskip("openai")
    table, schema = get_test_table(tmp_path, use_tantivy)
    reranker = OpenaiReranker()
    _run_test_reranker(reranker, table, "single player experience", None, schema)


@pytest.mark.skipif(
    os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set"
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_jina_reranker(tmp_path, use_tantivy):
    pytest.importorskip("jina")
    table, schema = get_test_table(tmp_path, use_tantivy)
    reranker = JinaReranker()
    _run_test_reranker(reranker, table, "single player experience", None, schema)


@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_voyageai_reranker(tmp_path, use_tantivy):
    pytest.importorskip("voyageai")
    reranker = VoyageAIReranker(model_name="rerank-2")
    table, schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(reranker, table, "single player experience", None, schema)


def test_empty_result_reranker():
    pytest.importorskip("sentence_transformers")
    db = lancedb.connect("memory://")
    # Define schema
    schema = pa.schema(
        [
            ("id", pa.int64()),
            ("text", pa.string()),
            ("vector", pa.list_(pa.float32(), 128)),  # 128-dimensional vector
        ]
    )
    # Create empty table with schema
    empty_table = db.create_table("empty_table", schema=schema, mode="overwrite")
    empty_table.create_fts_index("text", use_tantivy=False, replace=True)
    for reranker in [
        CrossEncoderReranker(),
        # ColbertReranker(),
        # AnswerdotaiRerankers(),
        # OpenaiReranker(),
        # JinaReranker(),
        # VoyageAIReranker(model_name="rerank-2"),
    ]:
        results = (
            empty_table.search(list(range(128)))
            .limit(3)
            .rerank(reranker, "query")
            .to_arrow()
        )
        # check if empty set contains _relevance_score column
        assert "_relevance_score" in results.column_names
        assert len(results) == 0

        results = (
            empty_table.search("query", query_type="fts")
            .limit(3)
            .rerank(reranker)
            .to_arrow()
        )


@pytest.mark.parametrize("use_tantivy", [True, False])
def test_cross_encoder_reranker_return_all(tmp_path, use_tantivy):
    pytest.importorskip("sentence_transformers")
    reranker = CrossEncoderReranker(return_score="all")
    table, schema = get_test_table(tmp_path, use_tantivy)
    query = "single player experience"
    result = (
        table.search(query, query_type="hybrid", vector_column_name="vector")
        .rerank(reranker=reranker)
        .to_arrow()
    )
    assert "_relevance_score" in result.column_names
    assert "_score" in result.column_names
    assert "_distance" in result.column_names