Files
lancedb/python/tests/test_rerankers.py
Ayush Chaurasia 3ffed89793 feat(python): Hybrid search & Reranker API (#824)
Based on https://github.com/lancedb/lancedb/pull/713.
- The Reranker API can also be plugged into vector-only or FTS-only search,
but this PR doesn't do that (see example:
https://txt.cohere.com/rerank/)


### Default reranker -- `LinearCombinationReranker(weight=0.7, fill=1.0)`

```
table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas()
```
### Available rerankers
LinearCombinationReranker
```
from lancedb.rerankers import LinearCombinationReranker

# Same as default 
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=LinearCombinationReranker()
                                     ).to_pandas()

# with custom params
reranker = LinearCombinationReranker(weight=0.3, fill=1.0)
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=reranker
                                     ).to_pandas()
```

Cohere Reranker
```
from lancedb.rerankers import CohereReranker

# Default model; English and multilingual models are supported. See the docstring for available custom params
table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank",  # score or rank
                                      reranker=CohereReranker()
                                     ).to_pandas()

```

CrossEncoderReranker

```
from lancedb.rerankers import CrossEncoderReranker

table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank", 
                                      reranker=CrossEncoderReranker()
                                     ).to_pandas()

```

## Using a custom Reranker
```
from lancedb.rerankers import Reranker

class CustomReranker(Reranker):
    def rerank_hybrid(self, vector_results, fts_results):
        combined_res = self.merge_results(vector_results, fts_results)  # or use custom combination logic
        # Custom rerank logic here

        return combined_res
```

- [x] Expand testing
- [x] Make sure usage makes sense
- [x] Run simple benchmarks for correctness (seeing weird results from the
Cohere reranker in the toy example)
- Support diverse rerankers by default:
- [x] Cross encoding
- [x] Cohere
- [x] Reciprocal Rank Fusion

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
2024-01-30 19:10:33 +05:30

169 lines
5.7 KiB
Python

import os
import numpy as np
import pytest
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
from lancedb.table import LanceTable
def get_test_table(tmp_path):
    """Create and populate a small hybrid-searchable table for the tests.

    Connects to a LanceDB database rooted at ``tmp_path`` and creates a
    table named ``my_table``. The table's schema pairs a ``text`` source
    field with a ``vector`` field produced by the embedding function
    registered under ``"test"``. It then inserts a fixed list of phrases,
    supplying only ``text``, and builds a full-text-search index on
    ``text``.

    NOTE(review): the ``"test"`` registry entry is presumably the
    ``MockTextEmbeddingFunction`` imported at the top of this module —
    confirm.

    Returns:
        A ``(table, MyTable)`` tuple: the populated table and the
        ``LanceModel`` subclass used as its schema (the tests pass it to
        ``to_pydantic``).
    """
    conn = lancedb.connect(tmp_path)
    embedder = EmbeddingFunctionRegistry.get_instance().get("test")()

    # Schema: ``vector`` is declared as the embedding of the ``text`` field.
    class MyTable(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    table = LanceTable.create(conn, "my_table", schema=MyTable)

    # A varied corpus, large enough that result ordering is meaningfully
    # exercised by the rerankers.
    quotes = [
        "great kid don't get cocky",
        "now that's a name I haven't heard in a long time",
        "if you strike me down I shall become more powerful than you imagine",
        "I find your lack of faith disturbing",
        "I've got a bad feeling about this",
        "never tell me the odds",
        "I am your father",
        "somebody has to save our skins",
        "New strategy R2 let the wookiee win",
        "Arrrrggghhhhhhh",
        "I see a mansard roof through the trees",
        "I see a salty message written in the eves",
        "the ground beneath my feet",
        "the hot garbage and concrete",
        "and now the tops of buildings",
        "everybody with a worried mind could never forgive the sight",
        "of wicked snakes inside a place you thought was dignified",
        "I don't wanna live like this",
        "but I don't wanna die",
        "The templars want control",
        "the brotherhood of assassins want freedom",
        "if only they could both see the world as it really is",
        "there would be peace",
        "but the war goes on",
        "altair's legacy was a warning",
        "Kratos had a son",
        "he was a god",
        "the god of war",
        "but his son was mortal",
        "there hasn't been a good battlefield game since 2142",
        "I wish they would make another one",
        "campains are not as good as they used to be",
        "Multiplayer and open world games have destroyed the single player experience",
        "Maybe the future is console games",
        "I don't know",
    ]
    rows = [dict(text=quote) for quote in quotes]
    table.add(rows)

    # Hybrid search needs an FTS index alongside the vector column.
    table.create_fts_index("text")
    return table, MyTable
# NOTE: These tests are fairly loose; they should also check the results for correctness.
def test_linear_combination(tmp_path):
    """Smoke-test the default (linear-combination) reranker on hybrid search.

    Checks three things:
    * an explicit ``rerank(normalize="score")`` returns the same results as
      a hybrid search with no ``rerank()`` call;
    * ``rerank(normalize="rank")`` runs without error (its output is not
      asserted on);
    * ``_relevance_score`` of a reranked result is in descending order.
    """
    table, schema = get_test_table(tmp_path)

    # NOTE(review): the three queries differ only in trailing punctuation;
    # presumably intentional — confirm.
    # Explicit rerank with score normalization (default reranker).
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score")
        .to_pydantic(schema)
    )
    # Rank normalization: only exercised for errors, not compared.
    result2 = (  # noqa
        table.search("Our father who art in heaven.", query_type="hybrid")
        .rerank(normalize="rank")
        .to_pydantic(schema)
    )
    # No rerank() call at all.
    result3 = table.search(
        "Our father who art in heaven..", query_type="hybrid"
    ).to_pydantic(schema)
    # 1 & 3 should be the same: the explicit score-normalized rerank is
    # expected to match the hybrid search's default behavior.
    assert result1 == result3

    result = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .limit(50)
        .rerank(normalize="score")
        .to_arrow()
    )
    scores = result.column("_relevance_score").to_numpy()
    assert np.all(np.diff(scores) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_reranker(tmp_path):
    """Check ``CohereReranker`` on hybrid search.

    Skipped unless ``COHERE_API_KEY`` is set and the optional ``cohere``
    package is importable. Expects score- and rank-normalized reranking to
    return identical results, and ``_relevance_score`` to be descending.
    """
    pytest.importorskip("cohere")
    table, schema = get_test_table(tmp_path)
    # Share one reranker across all queries instead of constructing a new
    # one (and its client) for every search.
    reranker = CohereReranker()
    query = "Our father who art in heaven"

    result1 = (
        table.search(query, query_type="hybrid")
        .rerank(normalize="score", reranker=reranker)
        .to_pydantic(schema)
    )
    result2 = (
        table.search(query, query_type="hybrid")
        .rerank(normalize="rank", reranker=reranker)
        .to_pydantic(schema)
    )
    # The test expects the normalization mode not to change the final order.
    assert result1 == result2

    result = (
        table.search(query, query_type="hybrid")
        .limit(50)
        .rerank(reranker=reranker)
        .to_arrow()
    )
    scores = result.column("_relevance_score").to_numpy()
    assert np.all(np.diff(scores) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
def test_cross_encoder_reranker(tmp_path):
    """Check ``CrossEncoderReranker`` on hybrid search.

    Skipped unless the optional ``sentence_transformers`` package is
    importable. Expects score- and rank-normalized reranking to return
    identical results, and ``_relevance_score`` to be descending.
    """
    pytest.importorskip("sentence_transformers")
    table, schema = get_test_table(tmp_path)
    # Share one reranker across all queries so the cross-encoder model is
    # set up once rather than once per search.
    reranker = CrossEncoderReranker()
    query = "Our father who art in heaven"

    result1 = (
        table.search(query, query_type="hybrid")
        .rerank(normalize="score", reranker=reranker)
        .to_pydantic(schema)
    )
    result2 = (
        table.search(query, query_type="hybrid")
        .rerank(normalize="rank", reranker=reranker)
        .to_pydantic(schema)
    )
    # The test expects the normalization mode not to change the final order.
    assert result1 == result2

    result = (
        table.search(query, query_type="hybrid")
        .limit(50)
        .rerank(reranker=reranker)
        .to_arrow()
    )
    scores = result.column("_relevance_score").to_numpy()
    assert np.all(np.diff(scores) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )