mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
fix(python): Update to latest cohere reranking api (#1212)
Fixes https://github.com/lancedb/lancedb/issues/1196 Cohere introduced a breaking change in their reranker API starting version 5.0.0. More context in discussion here https://github.com/cohere-ai/cohere-python/issues/446
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import semver
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
@@ -42,6 +43,14 @@ class CohereReranker(Reranker):
|
||||
@cached_property
|
||||
def _client(self):
|
||||
cohere = attempt_import_or_raise("cohere")
|
||||
# ensure version is at least 0.5.0
|
||||
if (
|
||||
hasattr(cohere, "__version__")
|
||||
and semver.compare(cohere.__version__, "5.0.0") < 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"cohere version must be at least 0.5.0, found {cohere.__version__}"
|
||||
)
|
||||
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"COHERE_API_KEY not set. Either set it in your environment or \
|
||||
@@ -51,11 +60,14 @@ class CohereReranker(Reranker):
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
results = self._client.rerank(
|
||||
response = self._client.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
top_n=self.top_n,
|
||||
model=self.model_name,
|
||||
)
|
||||
results = (
|
||||
response.results
|
||||
) # returns list (text, idx, relevance) attributes sorted descending by score
|
||||
indices, scores = list(
|
||||
zip(*[(result.index, result.relevance_score) for result in results])
|
||||
|
||||
Reference in New Issue
Block a user