diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py
index 3d940810..bcd6d2cd 100644
--- a/python/python/lancedb/embeddings/base.py
+++ b/python/python/lancedb/embeddings/base.py
@@ -153,7 +153,7 @@ class TextEmbeddingFunction(EmbeddingFunction):
 
     @abstractmethod
     def generate_embeddings(
-        self, texts: Union[List[str], np.ndarray]
+        self, texts: Union[List[str], np.ndarray], *args, **kwargs
     ) -> List[np.array]:
         """
         Generate the embeddings for the given texts
diff --git a/python/python/lancedb/embeddings/bedrock.py b/python/python/lancedb/embeddings/bedrock.py
index 767faa65..c7105370 100644
--- a/python/python/lancedb/embeddings/bedrock.py
+++ b/python/python/lancedb/embeddings/bedrock.py
@@ -73,6 +73,8 @@ class BedRockText(TextEmbeddingFunction):
     assumed_role: Union[str, None] = None
     profile_name: Union[str, None] = None
     role_session_name: str = "lancedb-embeddings"
+    source_input_type: str = "search_document"
+    query_input_type: str = "search_query"
 
     if PYDANTIC_VERSION.major < 2:
         # Pydantic 1.x compat
@@ -87,21 +89,29 @@ class BedRockText(TextEmbeddingFunction):
         # TODO: fix hardcoding
         if self.name == "amazon.titan-embed-text-v1":
             return 1536
-        elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
+        elif self.name in [
+            "amazon.titan-embed-text-v2:0",
+            "cohere.embed-english-v3",
+            "cohere.embed-multilingual-v3",
+        ]:
+            # TODO: "amazon.titan-embed-text-v2:0" model supports dynamic ndims
             return 1024
         else:
-            raise ValueError(f"Unknown model name: {self.name}")
+            raise ValueError(f"Model {self.name} not supported")
 
     def compute_query_embeddings(
         self, query: str, *args, **kwargs
     ) -> List[List[float]]:
-        return self.compute_source_embeddings(query)
+        return self.compute_source_embeddings(query, input_type=self.query_input_type)
 
     def compute_source_embeddings(
         self, texts: TEXT, *args, **kwargs
     ) -> List[List[float]]:
         texts = self.sanitize_input(texts)
-        return self.generate_embeddings(texts)
+        # assume source input type if not passed by `compute_query_embeddings`
+        kwargs["input_type"] = kwargs.get("input_type") or self.source_input_type
+
+        return self.generate_embeddings(texts, **kwargs)
 
     def generate_embeddings(
         self, texts: Union[List[str], np.ndarray], *args, **kwargs
@@ -121,11 +131,11 @@ class BedRockText(TextEmbeddingFunction):
         """
         results = []
         for text in texts:
-            response = self._generate_embedding(text)
+            response = self._generate_embedding(text, *args, **kwargs)
             results.append(response)
         return results
 
-    def _generate_embedding(self, text: str) -> List[float]:
+    def _generate_embedding(self, text: str, *args, **kwargs) -> List[float]:
         """
         Get the embeddings for the given texts
 
@@ -141,14 +151,12 @@ class BedRockText(TextEmbeddingFunction):
         """
         # format input body for provider
         provider = self.name.split(".")[0]
-        _model_kwargs = {}
-        input_body = {**_model_kwargs}
+        input_body = {**kwargs}
         if provider == "cohere":
-            if "input_type" not in input_body.keys():
-                input_body["input_type"] = "search_document"
             input_body["texts"] = [text]
         else:
             # includes common provider == "amazon"
+            input_body.pop("input_type", None)
             input_body["inputText"] = text
 
         body = json.dumps(input_body)
diff --git a/python/python/lancedb/embeddings/cohere.py b/python/python/lancedb/embeddings/cohere.py
index 29d203c0..03e41814 100644
--- a/python/python/lancedb/embeddings/cohere.py
+++ b/python/python/lancedb/embeddings/cohere.py
@@ -19,7 +19,7 @@ import numpy as np
 from ..util import attempt_import_or_raise
 from .base import TextEmbeddingFunction
 from .registry import register
-from .utils import api_key_not_found_help
+from .utils import api_key_not_found_help, TEXT
 
 
 @register("cohere")
@@ -32,8 +32,36 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
     Parameters
     ----------
     name: str, default "embed-multilingual-v2.0"
-        The name of the model to use. See the Cohere documentation for
-        a list of available models.
+        The name of the model to use. List of acceptable models:
+
+            * embed-english-v3.0
+            * embed-multilingual-v3.0
+            * embed-english-light-v3.0
+            * embed-multilingual-light-v3.0
+            * embed-english-v2.0
+            * embed-english-light-v2.0
+            * embed-multilingual-v2.0
+
+    source_input_type: str, default "search_document"
+        The input type for the source column in the database
+
+    query_input_type: str, default "search_query"
+        The input type for the query column in the database
+
+    Cohere supports following input types:
+
+    | Input Type               | Description                             |
+    |--------------------------|-----------------------------------------|
+    | "`search_document`"      | Used for embeddings stored in a vector  |
+    |                          | database for search use-cases.          |
+    | "`search_query`"         | Used for embeddings of search queries   |
+    |                          | run against a vector DB                 |
+    | "`semantic_similarity`"  | Specifies the given text will be used   |
+    |                          | for Semantic Textual Similarity (STS)   |
+    | "`classification`"       | Used for embeddings passed through a    |
+    |                          | text classifier.                        |
+    | "`clustering`"           | Used for the embeddings run through a   |
+    |                          | clustering algorithm                    |
 
     Examples
     --------
@@ -61,14 +89,39 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
     """
 
     name: str = "embed-multilingual-v2.0"
+    source_input_type: str = "search_document"
+    query_input_type: str = "search_query"
     client: ClassVar = None
 
     def ndims(self):
         # TODO: fix hardcoding
-        return 768
+        if self.name in [
+            "embed-english-v3.0",
+            "embed-multilingual-v3.0",
+            "embed-english-light-v2.0",
+        ]:
+            return 1024
+        elif self.name in ["embed-english-light-v3.0", "embed-multilingual-light-v3.0"]:
+            return 384
+        elif self.name == "embed-english-v2.0":
+            return 4096
+        elif self.name == "embed-multilingual-v2.0":
+            return 768
+        else:
+            raise ValueError(f"Model {self.name} not supported")
+
+    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
+        return self.compute_source_embeddings(query, input_type=self.query_input_type)
+
+    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
+        texts = self.sanitize_input(texts)
+        input_type = (
+            kwargs.get("input_type") or self.source_input_type
+        )  # assume source input type if not passed by `compute_query_embeddings`
+        return self.generate_embeddings(texts, input_type=input_type)
 
     def generate_embeddings(
-        self, texts: Union[List[str], np.ndarray]
+        self, texts: Union[List[str], np.ndarray], *args, **kwargs
     ) -> List[np.array]:
         """
         Get the embeddings for the given texts
@@ -78,9 +131,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
         texts: list[str] or np.ndarray (of str)
             The texts to embed
         """
-        # TODO retry, rate limit, token limit
         self._init_client()
-        rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
+        rs = CohereEmbeddingFunction.client.embed(
+            texts=texts, model=self.name, **kwargs
+        )
         return [emb for emb in rs.embeddings]
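For reviewers, a minimal usage sketch of the new `source_input_type` / `query_input_type` options (not part of the patch). It assumes the registry-based workflow from the LanceDB docs and a valid `COHERE_API_KEY` in the environment; the database path, table name, and sample rows are illustrative placeholders. Rows ingested through the source field go through `compute_source_embeddings` (and therefore `source_input_type`), while a text query passed to `table.search(...)` goes through `compute_query_embeddings` (and therefore `query_input_type`).

```python
# Sketch only: exercises the new input-type fields added by this patch.
# Assumes COHERE_API_KEY is set; path, table name, and data are placeholders.
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# `create()` forwards kwargs to the pydantic model, so the new fields can be
# set explicitly here (these values are also the defaults added by the patch).
cohere = (
    get_registry()
    .get("cohere")
    .create(
        name="embed-english-v3.0",
        source_input_type="search_document",
        query_input_type="search_query",
    )
)


class Docs(LanceModel):
    text: str = cohere.SourceField()
    vector: Vector(cohere.ndims()) = cohere.VectorField()


db = lancedb.connect("/tmp/lancedb-cohere-demo")  # placeholder location
table = db.create_table("docs", schema=Docs, mode="overwrite")

# Ingestion embeds with input_type="search_document" via compute_source_embeddings.
table.add([{"text": "Albert Einstein was a theoretical physicist."}])

# A text query embeds with input_type="search_query" via compute_query_embeddings.
print(table.search("famous physicists").limit(1).to_pandas())
```

The same two fields are added to `BedRockText`; there the `input_type` kwarg is forwarded to Cohere models served through Bedrock and stripped for Amazon Titan models, as the updated `_generate_embedding` in the patch shows.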