@register("cohere")
class CohereEmbeddingFunction(TextEmbeddingFunction):
    """
    An embedding function that uses the Cohere API

    https://docs.cohere.com/docs/multilingual-language-models

    Parameters
    ----------
    name: str, default "embed-multilingual-v2.0"
        The name of the model to use. See the Cohere documentation for a list of available models.

    Notes
    -----
    The API key is read from the ``COHERE_API_KEY`` environment variable the
    first time embeddings are generated. The client object is stored on the
    class, so it is created once and shared by every instance.

    Examples
    --------
    import lancedb
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.embeddings import EmbeddingFunctionRegistry

    cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0")

    class TextModel(LanceModel):
        text: str = cohere.SourceField()
        vector: Vector(cohere.ndims()) = cohere.VectorField()

    data = [ { "text": "hello world" },
            { "text": "goodbye world" }]

    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(data)

    """

    name: str = "embed-multilingual-v2.0"
    # Shared by all instances; populated lazily by _init_client().
    client: ClassVar = None

    def ndims(self):
        """
        Return the length of the vectors produced by the configured model.

        Returns
        -------
        int
            The embedding dimension for ``self.name``. Model names that are
            not in the table below fall back to 768, which was the previous
            hard-coded value.
        """
        # Output sizes of the Cohere v2 embedding models. A wrong value here
        # produces a Vector column whose width does not match the data.
        known_dims = {
            "embed-english-v2.0": 4096,
            "embed-english-light-v2.0": 1024,
            "embed-multilingual-v2.0": 768,
        }
        return known_dims.get(self.name, 768)

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One embedding per input text, in the order returned by the API.

        Raises
        ------
        ValueError
            If the ``COHERE_API_KEY`` environment variable is not set.
        """
        # TODO retry, rate limit, token limit
        self._init_client()
        if isinstance(texts, np.ndarray):
            # The client sends the texts as a JSON payload; hand it plain
            # Python strings instead of a numpy array.
            texts = texts.tolist()
        rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)

        return list(rs.embeddings)

    def _init_client(self):
        """
        Create the shared Cohere client on first use.

        Does nothing when the class-level client already exists. Otherwise
        the API key is taken from ``COHERE_API_KEY``; when the variable is
        missing, ``api_key_not_found_help`` is called to report the problem.
        """
        # Imported on every call so a missing package is reported the same
        # way as before (safe_import is inherited from the base class).
        cohere = self.safe_import("cohere")
        if CohereEmbeddingFunction.client is not None:
            return
        api_key = os.environ.get("COHERE_API_KEY")
        if api_key is None:
            api_key_not_found_help("cohere")
        CohereEmbeddingFunction.client = cohere.Client(api_key)
def api_key_not_found_help(provider):
    """
    Report a missing API key for an embedding provider.

    Logs an error naming the provider, then raises so the caller stops
    before trying to build a client without credentials.

    Parameters
    ----------
    provider: str
        Provider name, e.g. "cohere". Its upper-cased form is used as the
        prefix of the environment variable named in the error message.

    Raises
    ------
    ValueError
        Always; the message tells the user which variable to set.
    """
    LOGGER.error(f"Could not find API key for {provider}.")
    # Built after logging so the log entry is emitted first, as before.
    hint = f"Please set the {provider.upper()}_API_KEY environment variable."
    raise ValueError(hint)
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_embedding_function(tmp_path):
    """
    Round-trip two texts through the Cohere embedding function.

    Checks that the stored vectors have the length reported by ``ndims()``.
    Skipped when no API key is configured or the ``cohere`` package is not
    installed.
    """
    # Skip instead of erroring when the optional dependency is missing.
    pytest.importorskip("cohere")

    cohere = (
        EmbeddingFunctionRegistry.get_instance()
        .get("cohere")
        .create(name="embed-multilingual-v2.0")
    )

    class TextModel(LanceModel):
        text: str = cohere.SourceField()
        vector: Vector(cohere.ndims()) = cohere.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    # Use pytest's per-test directory so nothing is written to the user's
    # home directory and runs do not share state.
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()