mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-28 03:20:39 +00:00
Previously, `return_score="all"` was supported only for the default reranker (RRF) and not for the model-based rerankers. This adds support for keeping all scores in the base reranker so that all model-based rerankers can use it. It's a slower path than keeping just the relevance score, but it can be useful for debugging.
116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
|
|
import os
|
|
from packaging.version import Version
|
|
from functools import cached_property
|
|
from typing import Union
|
|
|
|
import pyarrow as pa
|
|
|
|
from ..util import attempt_import_or_raise
|
|
from .base import Reranker
|
|
|
|
|
|
class CohereReranker(Reranker):
    """
    Reranks the results using the Cohere Rerank API.
    https://docs.cohere.com/docs/rerank-guide

    Parameters
    ----------
    model_name : str, default "rerank-english-v3.0"
        The name of the Cohere rerank model to use. It is passed unchanged to
        the Cohere client, for example:
            - rerank-english-v3.0
            - rerank-english-v2.0
            - rerank-multilingual-v2.0
    column : str, default "text"
        The name of the column whose values are sent to the rerank model as
        documents.
    top_n : int, default None
        The number of results to return. If None, will return all results.
    return_score : str, default "relevance"
        Which score columns to keep in the output. With "relevance" only the
        `_relevance_score` column is kept; with "all" the original score
        columns are kept alongside it.
    api_key : str, default None
        The Cohere API key. The `COHERE_API_KEY` environment variable takes
        precedence over this argument when it is set to a non-empty value.
    """

    def __init__(
        self,
        model_name: str = "rerank-english-v3.0",
        column: str = "text",
        top_n: Union[int, None] = None,
        return_score: str = "relevance",
        api_key: Union[str, None] = None,
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.top_n = top_n
        self.api_key = api_key

    @cached_property
    def _client(self):
        """
        Lazily create (and cache) the Cohere client.

        Raises
        ------
        ValueError
            If the installed `cohere` package is older than 0.5.0, or if no
            API key is available from the environment or the constructor.
        """
        cohere = attempt_import_or_raise("cohere")
        # ensure version is at least 0.5.0; packages that do not expose
        # `__version__` skip this check
        if hasattr(cohere, "__version__") and Version(cohere.__version__) < Version(
            "0.5.0"
        ):
            raise ValueError(
                f"cohere version must be at least 0.5.0, found {cohere.__version__}"
            )

        env_key = os.environ.get("COHERE_API_KEY")
        if env_key is None and self.api_key is None:
            # Adjacent literals are used instead of a backslash continuation
            # so that source indentation does not leak into the message.
            raise ValueError(
                "COHERE_API_KEY not set. Either set it in your environment or "
                "pass it as `api_key` argument to the CohereReranker."
            )
        # A non-empty environment variable wins over the constructor argument.
        return cohere.Client(env_key or self.api_key)

    def _rerank(self, result_set: pa.Table, query: str) -> pa.Table:
        """
        Rerank `result_set` against `query` using the Cohere Rerank API.

        The rows are reordered to follow the order of the API response and a
        float32 `_relevance_score` column is appended.
        """
        result_set = self._handle_empty_results(result_set)
        if len(result_set) == 0:
            return result_set

        docs = result_set[self.column].to_pylist()
        response = self._client.rerank(
            query=query,
            documents=docs,
            top_n=self.top_n,
            model=self.model_name,
        )
        # Each result exposes the index of the document in `docs` and its
        # relevance score; the API is expected to return them sorted by
        # descending score.
        results = response.results

        # Build the two columns independently rather than unpacking
        # `zip(*...)`, which raises ValueError when the response is empty.
        indices = [result.index for result in results]
        scores = [result.relevance_score for result in results]

        # An explicitly typed index array keeps `take` well-defined even when
        # `indices` is empty (an untyped empty list is inferred as null type).
        result_set = result_set.take(pa.array(indices, type=pa.int64()))
        # add the scores
        result_set = result_set.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )

        return result_set

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Merge vector and full-text search results, then rerank them.

        When `return_score` is "all" the merge keeps the per-source score
        columns; otherwise the plain merge is used and, for "relevance", the
        result is reduced via `_keep_relevance_score`.
        """
        if self.score == "all":
            combined_results = self._merge_and_keep_scores(vector_results, fts_results)
        else:
            combined_results = self.merge_results(vector_results, fts_results)
        combined_results = self._rerank(combined_results, query)
        if self.score == "relevance":
            combined_results = self._keep_relevance_score(combined_results)

        return combined_results

    def rerank_vector(self, query: str, vector_results: pa.Table) -> pa.Table:
        """
        Rerank vector search results.

        The `_distance` column is dropped when `return_score` is "relevance".
        """
        vector_results = self._rerank(vector_results, query)
        if self.score == "relevance":
            vector_results = vector_results.drop_columns(["_distance"])
        return vector_results

    def rerank_fts(self, query: str, fts_results: pa.Table) -> pa.Table:
        """
        Rerank full-text search results.

        The `_score` column is dropped when `return_score` is "relevance".
        """
        fts_results = self._rerank(fts_results, query)
        if self.score == "relevance":
            fts_results = fts_results.drop_columns(["_score"])
        return fts_results
|