Add cohere embedding function (#550)

This commit is contained in:
Ayush Chaurasia
2023-10-13 16:27:34 +05:30
committed by GitHub
parent db7bdefe77
commit 683824f1e9
5 changed files with 118 additions and 1 deletions

View File

@@ -12,6 +12,7 @@
# limitations under the License.
from .cohere import CohereEmbeddingFunction
from .functions import (
EmbeddingFunction,
EmbeddingFunctionConfig,

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import ClassVar, List, Union
import numpy as np
from .functions import TextEmbeddingFunction, register
from .utils import api_key_not_found_help
@register("cohere")
class CohereEmbeddingFunction(TextEmbeddingFunction):
"""
An embedding function that uses the Cohere API
https://docs.cohere.com/docs/multilingual-language-models
Parameters
----------
name: str, default "embed-multilingual-v2.0"
The name of the model to use. See the Cohere documentation for a list of available models.
Examples
--------
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0")
class TextModel(LanceModel):
text: str = cohere.SourceField()
vector: Vector(cohere.ndims()) = cohere.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
"""
name: str = "embed-multilingual-v2.0"
client: ClassVar = None
def ndims(self):
# TODO: fix hardcoding
return 768
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
self._init_client()
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
return [emb for emb in rs.embeddings]
def _init_client(self):
cohere = self.safe_import("cohere")
if CohereEmbeddingFunction.client is None:
if os.environ.get("COHERE_API_KEY") is None:
api_key_not_found_help("cohere")
CohereEmbeddingFunction.client = cohere.Client(os.environ["COHERE_API_KEY"])

View File

@@ -21,6 +21,7 @@ from lance.vector import vec_to_table
from retry import retry
from ..util import safe_import_pandas
from ..utils.general import LOGGER
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
@@ -152,3 +153,8 @@ class FunctionWrapper:
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
else:
yield from _chunker(arr)
def api_key_not_found_help(provider):
LOGGER.error(f"Could not find API key for {provider}.")
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")