diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index a1ccb060..373e76b8 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -1,4 +1,5 @@ import os +import semver from functools import cached_property from typing import Union @@ -42,6 +43,14 @@ class CohereReranker(Reranker): @cached_property def _client(self): cohere = attempt_import_or_raise("cohere") + # ensure version is at least 0.5.0 + if ( + hasattr(cohere, "__version__") + and semver.compare(cohere.__version__, "5.0.0") < 0 + ): + raise ValueError( + f"cohere version must be at least 0.5.0, found {cohere.__version__}" + ) if os.environ.get("COHERE_API_KEY") is None and self.api_key is None: raise ValueError( "COHERE_API_KEY not set. Either set it in your environment or \ @@ -51,11 +60,14 @@ class CohereReranker(Reranker): def _rerank(self, result_set: pa.Table, query: str): docs = result_set[self.column].to_pylist() - results = self._client.rerank( + response = self._client.rerank( query=query, documents=docs, top_n=self.top_n, model=self.model_name, + ) + results = ( + response.results ) # returns list (text, idx, relevance) attributes sorted descending by score indices, scores = list( zip(*[(result.index, result.relevance_score) for result in results])