Add cohere embedding function (#550)

This commit is contained in:
Ayush Chaurasia
2023-10-13 16:27:34 +05:30
committed by GitHub
parent db7bdefe77
commit 683824f1e9
5 changed files with 118 additions and 1 deletions

View File

@@ -12,6 +12,7 @@
# limitations under the License.
from .cohere import CohereEmbeddingFunction
from .functions import (
EmbeddingFunction,
EmbeddingFunctionConfig,

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import ClassVar, List, Union
import numpy as np
from .functions import TextEmbeddingFunction, register
from .utils import api_key_not_found_help
@register("cohere")
class CohereEmbeddingFunction(TextEmbeddingFunction):
"""
An embedding function that uses the Cohere API
https://docs.cohere.com/docs/multilingual-language-models
Parameters
----------
name: str, default "embed-multilingual-v2.0"
The name of the model to use. See the Cohere documentation for a list of available models.
Examples
--------
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0")
class TextModel(LanceModel):
text: str = cohere.SourceField()
vector: Vector(cohere.ndims()) = cohere.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
"""
name: str = "embed-multilingual-v2.0"
client: ClassVar = None
def ndims(self):
# TODO: fix hardcoding
return 768
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
# TODO retry, rate limit, token limit
self._init_client()
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
return [emb for emb in rs.embeddings]
def _init_client(self):
cohere = self.safe_import("cohere")
if CohereEmbeddingFunction.client is None:
if os.environ.get("COHERE_API_KEY") is None:
api_key_not_found_help("cohere")
CohereEmbeddingFunction.client = cohere.Client(os.environ["COHERE_API_KEY"])

View File

@@ -21,6 +21,7 @@ from lance.vector import vec_to_table
from retry import retry
from ..util import safe_import_pandas
from ..utils.general import LOGGER
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
@@ -152,3 +153,8 @@ class FunctionWrapper:
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
else:
yield from _chunker(arr)
def api_key_not_found_help(provider):
LOGGER.error(f"Could not find API key for {provider}.")
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")

View File

@@ -52,7 +52,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip", "cohere"]
[project.scripts]
lancedb = "lancedb.cli.cli:cli"

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import numpy as np
import pandas as pd
@@ -123,3 +124,26 @@ def test_openclip(tmp_path):
arrow_table["vector"].combine_chunks().values.to_numpy(),
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
)
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
) # also skip if cohere not installed
def test_cohere_embedding_function():
cohere = (
EmbeddingFunctionRegistry.get_instance()
.get("cohere")
.create(name="embed-multilingual-v2.0")
)
class TextModel(LanceModel):
text: str = cohere.SourceField()
vector: Vector(cohere.ndims()) = cohere.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()