mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 04:12:59 +00:00
feat(python): Hybrid search & Reranker API (#824)
based on https://github.com/lancedb/lancedb/pull/713 - The Reranker api can be plugged into vector only or fts only search but this PR doesn't do that (see example - https://txt.cohere.com/rerank/) ### Default reranker -- `LinearCombinationReranker(weight=0.7, fill=1.0)` ``` table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas() ``` ### Available rerankers LinearCombinationReranker ``` from lancedb.rerankers import LinearCombinationReranker # Same as default table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=LinearCombinationReranker() ).to_pandas() # with custom params reranker = LinearCombinationReranker(weight=0.3, fill=1.0) table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=reranker ).to_pandas() ``` Cohere Reranker ``` from lancedb.rerankers import CohereReranker # default model.. English and multi-lingual supported. See docstring for available custom params table.search("hello", query_type="hybrid").rerank( normalize="rank", # score or rank reranker=CohereReranker() ).to_pandas() ``` CrossEncoderReranker ``` from lancedb.rerankers import CrossEncoderReranker table.search("hello", query_type="hybrid").rerank( normalize="rank", reranker=CrossEncoderReranker() ).to_pandas() ``` ## Using custom Reranker ``` from lancedb.reranker import Reranker class CustomReranker(Reranker): def rerank_hybrid(self, vector_result, fts_result): combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic # Custom rerank logic here return combined_res ``` - [x] Expand testing - [x] Make sure usage makes sense - [x] Run simple benchmarks for correctness (Seeing weird result from cohere reranker in the toy example) - Support diverse rerankers by default: - [x] Cross encoding - [x] Cohere - [x] Reciprocal Rank Fusion --------- Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com> Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
This commit is contained in:
committed by
Weston Pace
parent
ecbbe185c7
commit
a41f7be88d
168
python/tests/test_rerankers.py
Normal file
168
python/tests/test_rerankers.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
def get_test_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
|
||||
# Initialize the table using the schema
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
# Need to test with a bunch of phrases to make sure sorting is consistent
|
||||
phrases = [
|
||||
"great kid don't get cocky",
|
||||
"now that's a name I haven't heard in a long time",
|
||||
"if you strike me down I shall become more powerful than you imagine",
|
||||
"I find your lack of faith disturbing",
|
||||
"I've got a bad feeling about this",
|
||||
"never tell me the odds",
|
||||
"I am your father",
|
||||
"somebody has to save our skins",
|
||||
"New strategy R2 let the wookiee win",
|
||||
"Arrrrggghhhhhhh",
|
||||
"I see a mansard roof through the trees",
|
||||
"I see a salty message written in the eves",
|
||||
"the ground beneath my feet",
|
||||
"the hot garbage and concrete",
|
||||
"and now the tops of buildings",
|
||||
"everybody with a worried mind could never forgive the sight",
|
||||
"of wicked snakes inside a place you thought was dignified",
|
||||
"I don't wanna live like this",
|
||||
"but I don't wanna die",
|
||||
"The templars want control",
|
||||
"the brotherhood of assassins want freedom",
|
||||
"if only they could both see the world as it really is",
|
||||
"there would be peace",
|
||||
"but the war goes on",
|
||||
"altair's legacy was a warning",
|
||||
"Kratos had a son",
|
||||
"he was a god",
|
||||
"the god of war",
|
||||
"but his son was mortal",
|
||||
"there hasn't been a good battlefield game since 2142",
|
||||
"I wish they would make another one",
|
||||
"campains are not as good as they used to be",
|
||||
"Multiplayer and open world games have destroyed the single player experience",
|
||||
"Maybe the future is console games",
|
||||
"I don't know",
|
||||
]
|
||||
|
||||
# Add the phrases and vectors to the table
|
||||
table.add([{"text": p} for p in phrases])
|
||||
|
||||
# Create a fts index
|
||||
table.create_fts_index("text")
|
||||
|
||||
return table, MyTable
|
||||
|
||||
|
||||
## These tests are pretty loose, we should also check for correctness
|
||||
def test_linear_combination(tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score")
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = ( # noqa
|
||||
table.search("Our father who art in heaven.", query_type="hybrid")
|
||||
.rerank(normalize="rank")
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result3 = table.search(
|
||||
"Our father who art in heaven..", query_type="hybrid"
|
||||
).to_pydantic(schema)
|
||||
|
||||
assert result1 == result3 # 2 & 3 should be the same as they use score as score
|
||||
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
)
|
||||
def test_cohere_reranker(tmp_path):
|
||||
pytest.importorskip("cohere")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=CohereReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CohereReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
|
||||
def test_cross_encoder_reranker(tmp_path):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=CrossEncoderReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CrossEncoderReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
@@ -682,3 +682,57 @@ def test_count_rows(db):
|
||||
assert len(table) == 2
|
||||
assert table.count_rows() == 2
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db):
|
||||
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
|
||||
# Probably not being parsed right by the fts
|
||||
db = MockDB("~/lancedb_")
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
|
||||
# Initialize the table using the schema
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
# Create a list of 10 unique english phrases
|
||||
phrases = [
|
||||
"great kid don't get cocky",
|
||||
"now that's a name I haven't heard in a long time",
|
||||
"if you strike me down I shall become more powerful than you imagine",
|
||||
"I find your lack of faith disturbing",
|
||||
"I've got a bad feeling about this",
|
||||
"never tell me the odds",
|
||||
"I am your father",
|
||||
"somebody has to save our skins",
|
||||
"New strategy R2 let the wookiee win",
|
||||
"Arrrrggghhhhhhh",
|
||||
]
|
||||
|
||||
# Add the phrases and vectors to the table
|
||||
table.add([{"text": p} for p in phrases])
|
||||
|
||||
# Create a fts index
|
||||
table.create_fts_index("text")
|
||||
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score")
|
||||
.to_pydantic(MyTable)
|
||||
)
|
||||
result2 = ( # noqa
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank")
|
||||
.to_pydantic(MyTable)
|
||||
)
|
||||
result3 = table.search(
|
||||
"Our father who art in heaven", query_type="hybrid"
|
||||
).to_pydantic(MyTable)
|
||||
assert result1 == result3
|
||||
|
||||
Reference in New Issue
Block a user