mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
feat: add support for new cohere models in cohere and bedrock embedding functions (#1335)
Fixes #1329 Will update docs on https://github.com/lancedb/lancedb/pull/1326
This commit is contained in:
@@ -153,7 +153,7 @@ class TextEmbeddingFunction(EmbeddingFunction):
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
|
||||
@@ -73,6 +73,8 @@ class BedRockText(TextEmbeddingFunction):
|
||||
assumed_role: Union[str, None] = None
|
||||
profile_name: Union[str, None] = None
|
||||
role_session_name: str = "lancedb-embeddings"
|
||||
source_input_type: str = "search_document"
|
||||
query_input_type: str = "search_query"
|
||||
|
||||
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
|
||||
|
||||
@@ -87,21 +89,29 @@ class BedRockText(TextEmbeddingFunction):
|
||||
# TODO: fix hardcoding
|
||||
if self.name == "amazon.titan-embed-text-v1":
|
||||
return 1536
|
||||
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
|
||||
elif self.name in [
|
||||
"amazon.titan-embed-text-v2:0",
|
||||
"cohere.embed-english-v3",
|
||||
"cohere.embed-multilingual-v3",
|
||||
]:
|
||||
# TODO: "amazon.titan-embed-text-v2:0" model supports dynamic ndims
|
||||
return 1024
|
||||
else:
|
||||
raise ValueError(f"Unknown model name: {self.name}")
|
||||
raise ValueError(f"Model {self.name} not supported")
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: str, *args, **kwargs
|
||||
) -> List[List[float]]:
|
||||
return self.compute_source_embeddings(query)
|
||||
return self.compute_source_embeddings(query, input_type=self.query_input_type)
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, texts: TEXT, *args, **kwargs
|
||||
) -> List[List[float]]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
# assume source input type if not passed by `compute_query_embeddings`
|
||||
kwargs["input_type"] = kwargs.get("input_type") or self.source_input_type
|
||||
|
||||
return self.generate_embeddings(texts, **kwargs)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
@@ -121,11 +131,11 @@ class BedRockText(TextEmbeddingFunction):
|
||||
"""
|
||||
results = []
|
||||
for text in texts:
|
||||
response = self._generate_embedding(text)
|
||||
response = self._generate_embedding(text, *args, **kwargs)
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
def _generate_embedding(self, text: str) -> List[float]:
|
||||
def _generate_embedding(self, text: str, *args, **kwargs) -> List[float]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
@@ -141,14 +151,12 @@ class BedRockText(TextEmbeddingFunction):
|
||||
"""
|
||||
# format input body for provider
|
||||
provider = self.name.split(".")[0]
|
||||
_model_kwargs = {}
|
||||
input_body = {**_model_kwargs}
|
||||
input_body = {**kwargs}
|
||||
if provider == "cohere":
|
||||
if "input_type" not in input_body.keys():
|
||||
input_body["input_type"] = "search_document"
|
||||
input_body["texts"] = [text]
|
||||
else:
|
||||
# includes common provider == "amazon"
|
||||
input_body.pop("input_type", None)
|
||||
input_body["inputText"] = text
|
||||
body = json.dumps(input_body)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help
|
||||
from .utils import api_key_not_found_help, TEXT
|
||||
|
||||
|
||||
@register("cohere")
|
||||
@@ -32,8 +32,36 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "embed-multilingual-v2.0"
|
||||
The name of the model to use. See the Cohere documentation for
|
||||
a list of available models.
|
||||
The name of the model to use. List of acceptable models:
|
||||
|
||||
* embed-english-v3.0
|
||||
* embed-multilingual-v3.0
|
||||
* embed-english-light-v3.0
|
||||
* embed-multilingual-light-v3.0
|
||||
* embed-english-v2.0
|
||||
* embed-english-light-v2.0
|
||||
* embed-multilingual-v2.0
|
||||
|
||||
source_input_type: str, default "search_document"
|
||||
The input type for the source column in the database
|
||||
|
||||
query_input_type: str, default "search_query"
|
||||
The input type for the query column in the database
|
||||
|
||||
Cohere supports following input types:
|
||||
|
||||
| Input Type | Description |
|
||||
|-------------------------|---------------------------------------|
|
||||
| "`search_document`" | Used for embeddings stored in a vector|
|
||||
| | database for search use-cases. |
|
||||
| "`search_query`" | Used for embeddings of search queries |
|
||||
| | run against a vector DB |
|
||||
| "`semantic_similarity`" | Specifies the given text will be used |
|
||||
| | for Semantic Textual Similarity (STS) |
|
||||
| "`classification`" | Used for embeddings passed through a |
|
||||
| | text classifier. |
|
||||
| "`clustering`" | Used for the embeddings run through a |
|
||||
| | clustering algorithm |
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -61,14 +89,39 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
|
||||
name: str = "embed-multilingual-v2.0"
|
||||
source_input_type: str = "search_document"
|
||||
query_input_type: str = "search_query"
|
||||
client: ClassVar = None
|
||||
|
||||
def ndims(self):
|
||||
# TODO: fix hardcoding
|
||||
return 768
|
||||
if self.name in [
|
||||
"embed-english-v3.0",
|
||||
"embed-multilingual-v3.0",
|
||||
"embed-english-light-v2.0",
|
||||
]:
|
||||
return 1024
|
||||
elif self.name in ["embed-english-light-v3.0", "embed-multilingual-light-v3.0"]:
|
||||
return 384
|
||||
elif self.name == "embed-english-v2.0":
|
||||
return 4096
|
||||
elif self.name == "embed-multilingual-v2.0":
|
||||
return 768
|
||||
else:
|
||||
raise ValueError(f"Model {self.name} not supported")
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, input_type=self.query_input_type)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
input_type = (
|
||||
kwargs.get("input_type") or self.source_input_type
|
||||
) # assume source input type if not passed by `compute_query_embeddings`
|
||||
return self.generate_embeddings(texts, input_type=input_type)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
@@ -78,9 +131,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
self._init_client()
|
||||
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
|
||||
rs = CohereEmbeddingFunction.client.embed(
|
||||
texts=texts, model=self.name, **kwargs
|
||||
)
|
||||
|
||||
return [emb for emb in rs.embeddings]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user