mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 03:12:57 +00:00
based on https://github.com/lancedb/lancedb/pull/713 - The Reranker API can be plugged into vector-only or FTS-only search, but this PR doesn't do that (see example - https://txt.cohere.com/rerank/) ### Default reranker -- `LinearCombinationReranker(weight=0.7, fill=1.0)` ``` table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas() ``` ### Available rerankers LinearCombinationReranker ``` from lancedb.rerankers import LinearCombinationReranker # Same as default table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=LinearCombinationReranker() ).to_pandas() # with custom params reranker = LinearCombinationReranker(weight=0.3, fill=1.0) table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=reranker ).to_pandas() ``` Cohere Reranker ``` from lancedb.rerankers import CohereReranker # default model.. English and multi-lingual supported. See docstring for available custom params table.search("hello", query_type="hybrid").rerank( normalize="rank", # score or rank reranker=CohereReranker() ).to_pandas() ``` CrossEncoderReranker ``` from lancedb.rerankers import CrossEncoderReranker table.search("hello", query_type="hybrid").rerank( normalize="rank", reranker=CrossEncoderReranker() ).to_pandas() ``` ## Using custom Reranker ``` from lancedb.rerankers import Reranker class CustomReranker(Reranker): def rerank_hybrid(self, vector_results, fts_results): combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic # Custom rerank logic here return combined_res ``` - [x] Expand testing - [x] Make sure usage makes sense - [x] Run simple benchmarks for correctness (Seeing weird result from cohere reranker in the toy example) - Support diverse rerankers by default: - [x] Cross encoding - [x] Cohere - [x] Reciprocal Rank Fusion --------- Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com> Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
79 lines
2.6 KiB
Python
79 lines
2.6 KiB
Python
import typing
|
|
from functools import cached_property
|
|
from typing import Union
|
|
|
|
import pyarrow as pa
|
|
|
|
from ..util import safe_import
|
|
from .base import Reranker
|
|
|
|
if typing.TYPE_CHECKING:
|
|
import lancedb
|
|
|
|
|
|
class CrossEncoderReranker(Reranker):
    """
    Reranks the results using a cross encoder model. The cross encoder model is
    used to score the query against each result. The results are then sorted by
    that score, highest first.

    Parameters
    ----------
    model_name : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
        The name of the cross encoder model to use. See the sentence transformers
        documentation for a list of available models.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    device : str, default None
        The device to use for the cross encoder model. If None, will use "cuda"
        if available, otherwise "cpu".
    return_score : str, default "relevance"
        Forwarded to the base ``Reranker``. With "relevance" the ``score`` and
        ``_distance`` columns are dropped from the output and only
        ``_relevance_score`` is kept. "all" is not implemented for this
        reranker.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-TinyBERT-L-6",
        column: str = "text",
        device: Union[str, None] = None,
        return_score="relevance",
    ):
        super().__init__(return_score)
        torch = safe_import("torch")
        self.model_name = model_name
        self.column = column
        self.device = device
        if self.device is None:
            # No explicit device requested: prefer the GPU when one is usable.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @cached_property
    def model(self):
        """Lazily construct the cross encoder (cached after first access)."""
        sbert = safe_import("sentence_transformers")
        # Pass the resolved device through; previously it was computed in
        # __init__ but never handed to the model, so `device` had no effect.
        return sbert.CrossEncoder(self.model_name, device=self.device)

    def rerank_hybrid(
        self,
        query_builder: "lancedb.HybridQueryBuilder",
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """
        Merge the vector and full-text results and rerank them with the
        cross encoder.

        Parameters
        ----------
        query_builder : lancedb.HybridQueryBuilder
            The query builder; its ``_query`` attribute is paired with each
            passage as the cross encoder input.
        vector_results : pa.Table
            Results of the vector search.
        fts_results : pa.Table
            Results of the full-text search.

        Returns
        -------
        pa.Table
            The merged results with an added float32 ``_relevance_score``
            column, sorted by that column in descending order.

        Raises
        ------
        NotImplementedError
            If the reranker was configured with ``return_score="all"``.
        """
        # Fail fast: do not pay for merging and model inference when the
        # requested score mode is unsupported.
        # NOTE(review): assumes the base Reranker stores `return_score` as
        # `self.score` -- confirm against the base class.
        if self.score == "all":
            raise NotImplementedError(
                "return_score='all' not implemented for CrossEncoderReranker"
            )

        combined_results = self.merge_results(vector_results, fts_results)
        passages = combined_results[self.column].to_pylist()
        cross_inp = [[query_builder._query, passage] for passage in passages]
        # Skip loading/calling the model entirely when there is nothing to score.
        cross_scores = self.model.predict(cross_inp) if cross_inp else []
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(cross_scores, type=pa.float32())
        )

        if self.score == "relevance":
            # Keep only the relevance score; tolerate a merged table that
            # lacks one of the per-search score columns.
            present = set(combined_results.column_names)
            to_drop = [name for name in ("score", "_distance") if name in present]
            if to_drop:
                combined_results = combined_results.drop_columns(to_drop)

        # Sort the results by relevance, most relevant first.
        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results
|