From 5d8c91256c0fc807ceabad55bf0ba746270004f0 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 11 Apr 2024 15:20:29 +0530 Subject: [PATCH] fix(python): Update to latest cohere reranking api (#1212) Fixes https://github.com/lancedb/lancedb/issues/1196 Cohere introduced a breaking change in their reranker API starting with version 5.0.0. More context is available in the discussion here: https://github.com/cohere-ai/cohere-python/issues/446 --- python/python/lancedb/rerankers/cohere.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index a1ccb060..373e76b8 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -1,4 +1,5 @@ import os +import semver from functools import cached_property from typing import Union @@ -42,6 +43,14 @@ class CohereReranker(Reranker): @cached_property def _client(self): cohere = attempt_import_or_raise("cohere") + # ensure version is at least 5.0.0 + if ( + hasattr(cohere, "__version__") + and semver.compare(cohere.__version__, "5.0.0") < 0 + ): + raise ValueError( + f"cohere version must be at least 0.5.0, found {cohere.__version__}" + ) if os.environ.get("COHERE_API_KEY") is None and self.api_key is None: raise ValueError( "COHERE_API_KEY not set. Either set it in your environment or \ @@ -51,11 +60,14 @@ class CohereReranker(Reranker): def _rerank(self, result_set: pa.Table, query: str): docs = result_set[self.column].to_pylist() - results = self._client.rerank( + response = self._client.rerank( query=query, documents=docs, top_n=self.top_n, model=self.model_name, + ) + results = ( + response.results ) # returns list (text, idx, relevance) attributes sorted descending by score indices, scores = list( zip(*[(result.index, result.relevance_score) for result in results])