fix(python): Update to latest cohere reranking api (#1212)

Fixes https://github.com/lancedb/lancedb/issues/1196
Cohere introduced a breaking change in their reranker API starting with
version 5.0.0. More context is in the discussion here:
https://github.com/cohere-ai/cohere-python/issues/446
This commit is contained in:
Ayush Chaurasia
2024-04-11 15:20:29 +05:30
committed by GitHub
parent 44c03ebef3
commit 5d8c91256c

View File

@@ -1,4 +1,5 @@
import os
import semver
from functools import cached_property
from typing import Union
@@ -42,6 +43,14 @@ class CohereReranker(Reranker):
@cached_property
def _client(self):
cohere = attempt_import_or_raise("cohere")
# ensure version is at least 5.0.0
if (
hasattr(cohere, "__version__")
and semver.compare(cohere.__version__, "5.0.0") < 0
):
raise ValueError(
f"cohere version must be at least 0.5.0, found {cohere.__version__}"
)
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
raise ValueError(
"COHERE_API_KEY not set. Either set it in your environment or \
@@ -51,11 +60,14 @@ class CohereReranker(Reranker):
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
results = self._client.rerank(
response = self._client.rerank(
query=query,
documents=docs,
top_n=self.top_n,
model=self.model_name,
)
results = (
response.results
) # returns list (text, idx, relevance) attributes sorted descending by score
indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results])