diff --git a/python/python/lancedb/embeddings/voyageai.py b/python/python/lancedb/embeddings/voyageai.py
index 161c5e43..ec2c5b6d 100644
--- a/python/python/lancedb/embeddings/voyageai.py
+++ b/python/python/lancedb/embeddings/voyageai.py
@@ -12,18 +12,22 @@
 # limitations under the License.
 
 import os
-from typing import ClassVar, List, Union
+from typing import ClassVar, TYPE_CHECKING, List, Union
 
 import numpy as np
+import pyarrow as pa
 
 from ..util import attempt_import_or_raise
-from .base import TextEmbeddingFunction
+from .base import EmbeddingFunction
 from .registry import register
-from .utils import api_key_not_found_help, TEXT
+from .utils import api_key_not_found_help, IMAGES
+
+if TYPE_CHECKING:
+    import PIL
 
 
 @register("voyageai")
-class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
+class VoyageAIEmbeddingFunction(EmbeddingFunction):
     """
     An embedding function that uses the VoyageAI API
 
@@ -36,6 +40,7 @@ class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
 
         * voyage-3
         * voyage-3-lite
+        * voyage-multimodal-3
         * voyage-finance-2
         * voyage-multilingual-2
         * voyage-law-2
@@ -77,6 +82,7 @@ class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
             return 1536
         elif self.name in [
             "voyage-3",
+            "voyage-multimodal-3",
             "voyage-finance-2",
             "voyage-multilingual-2",
             "voyage-law-2",
@@ -85,19 +91,19 @@ class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
         else:
             raise ValueError(f"Model {self.name} not supported")
 
-    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
-        return self.compute_source_embeddings(query, input_type="query")
+    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
+        """
+        Sanitize the input to the embedding function.
+        """
+        if isinstance(images, (str, bytes)):
+            images = [images]
+        elif isinstance(images, pa.Array):
+            images = images.to_pylist()
+        elif isinstance(images, pa.ChunkedArray):
+            images = images.combine_chunks().to_pylist()
+        return images
 
-    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
-        texts = self.sanitize_input(texts)
-        input_type = (
-            kwargs.get("input_type") or "document"
-        )  # assume source input type if not passed by `compute_query_embeddings`
-        return self.generate_embeddings(texts, input_type=input_type)
-
-    def generate_embeddings(
-        self, texts: Union[List[str], np.ndarray], *args, **kwargs
-    ) -> List[np.array]:
+    def generate_text_embeddings(self, text: str, **kwargs) -> np.ndarray:
         """
         Get the embeddings for the given texts
 
@@ -109,15 +115,70 @@ class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
 
         truncation: Optional[bool]
         """
-        VoyageAIEmbeddingFunction._init_client()
-        rs = VoyageAIEmbeddingFunction.client.embed(
-            texts=texts, model=self.name, **kwargs
-        )
+        if self.name in ["voyage-multimodal-3"]:
+            rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed(
+                inputs=[[text]], model=self.name, **kwargs
+            )
+        else:
+            rs = VoyageAIEmbeddingFunction._get_client().embed(
+                texts=[text], model=self.name, **kwargs
+            )
 
-        return [emb for emb in rs.embeddings]
+        return rs.embeddings[0]
+
+    def generate_image_embedding(
+        self, image: "PIL.Image.Image", **kwargs
+    ) -> np.ndarray:
+        """
+        Get the embedding for a single image via the multimodal endpoint
+        """
+        rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed(
+            inputs=[[image]], model=self.name, **kwargs
+        )
+        return rs.embeddings[0]
+
+    def compute_query_embeddings(
+        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
+    ) -> List[np.ndarray]:
+        """
+        Compute the embeddings for a given user query
+
+        Parameters
+        ----------
+        query : Union[str, PIL.Image.Image]
+            The query to embed. A query can be either text or an image.
+        """
+        if isinstance(query, str):
+            return [self.generate_text_embeddings(query, input_type="query")]
+        else:
+            PIL = attempt_import_or_raise("PIL", "pillow")
+            if isinstance(query, PIL.Image.Image):
+                return [self.generate_image_embedding(query, input_type="query")]
+            else:
+                raise TypeError("Only text or PIL images supported as query")
+
+    def compute_source_embeddings(
+        self, images: IMAGES, *args, **kwargs
+    ) -> List[np.ndarray]:
+        """
+        Compute the embeddings for the given source inputs
+
+        Parameters
+        ----------
+        images : IMAGES
+            The inputs to embed. Strings are embedded as text, so text-only
+            models keep working; any other item is embedded as an image.
+        """
+        images = self.sanitize_input(images)
+        return [
+            self.generate_text_embeddings(item, input_type="document")
+            if isinstance(item, str)
+            else self.generate_image_embedding(item, input_type="document")
+            for item in images
+        ]
 
     @staticmethod
-    def _init_client():
+    def _get_client():
         if VoyageAIEmbeddingFunction.client is None:
             voyageai = attempt_import_or_raise("voyageai")
             if os.environ.get("VOYAGE_API_KEY") is None:
@@ -125,3 +186,4 @@ class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
             VoyageAIEmbeddingFunction.client = voyageai.Client(
                 os.environ["VOYAGE_API_KEY"]
             )
+        return VoyageAIEmbeddingFunction.client