From 30ed8c4c43c849fedceb3195662b2e1a6330405a Mon Sep 17 00:00:00 2001 From: fzowl <160063452+fzowl@users.noreply.github.com> Date: Fri, 4 Apr 2025 23:45:56 +0200 Subject: [PATCH] fix: voyageai regression multimodal supercedes text models (#2268) fix #2160 --- python/pyproject.toml | 1 + python/python/lancedb/embeddings/voyageai.py | 202 +++++++++++++------ python/python/tests/test_embeddings_slow.py | 59 ++++++ 3 files changed, 202 insertions(+), 60 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 364582fc..22c0e591 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -56,6 +56,7 @@ tests = [ "tantivy", "pyarrow-stubs", "pylance>=0.23.2", + "requests", ] dev = [ "ruff", diff --git a/python/python/lancedb/embeddings/voyageai.py b/python/python/lancedb/embeddings/voyageai.py index 73d62703..b366329d 100644 --- a/python/python/lancedb/embeddings/voyageai.py +++ b/python/python/lancedb/embeddings/voyageai.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors - - +import base64 import os -from typing import ClassVar, TYPE_CHECKING, List, Union +from typing import ClassVar, TYPE_CHECKING, List, Union, Any + +from pathlib import Path +from urllib.parse import urlparse +from io import BytesIO import numpy as np import pyarrow as pa @@ -11,12 +14,100 @@ import pyarrow as pa from ..util import attempt_import_or_raise from .base import EmbeddingFunction from .registry import register -from .utils import api_key_not_found_help, IMAGES +from .utils import api_key_not_found_help, IMAGES, TEXT if TYPE_CHECKING: import PIL +def is_valid_url(text): + try: + parsed = urlparse(text) + return bool(parsed.scheme) and bool(parsed.netloc) + except Exception: + return False + + +def transform_input(input_data: Union[str, bytes, Path]): + PIL = attempt_import_or_raise("PIL", "pillow") + if isinstance(input_data, str): + if is_valid_url(input_data): + content = {"type": "image_url", "image_url": input_data} + else: + content = {"type": "text", "text": input_data} + elif isinstance(input_data, PIL.Image.Image): + buffered = BytesIO() + input_data.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + content = { + "type": "image_base64", + "image_base64": "data:image/jpeg;base64," + img_str, + } + elif isinstance(input_data, bytes): + img = PIL.Image.open(BytesIO(input_data)) + buffered = BytesIO() + img.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + content = { + "type": "image_base64", + "image_base64": "data:image/jpeg;base64," + img_str, + } + elif isinstance(input_data, Path): + img = PIL.Image.open(input_data) + buffered = BytesIO() + img.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + content = { + "type": "image_base64", + "image_base64": "data:image/jpeg;base64," + img_str, + } + else: + raise ValueError("Each input should be either str, bytes, Path or Image.") + + return {"content": [content]} + + +def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]: + """ + Sanitize the input to the embedding function. + """ + PIL = attempt_import_or_raise("PIL", "pillow") + if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)): + inputs = [inputs] + elif isinstance(inputs, pa.Array): + inputs = inputs.to_pylist() + elif isinstance(inputs, pa.ChunkedArray): + inputs = inputs.combine_chunks().to_pylist() + else: + raise ValueError( + f"Input type {type(inputs)} not allowed with multimodal model." + ) + + if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs): + raise ValueError("Each input should be either str, bytes, Path or Image.") + + return [transform_input(i) for i in inputs] + + +def sanitize_text_input(inputs: TEXT) -> List[str]: + """ + Sanitize the input to the embedding function. + """ + if isinstance(inputs, str): + inputs = [inputs] + elif isinstance(inputs, pa.Array): + inputs = inputs.to_pylist() + elif isinstance(inputs, pa.ChunkedArray): + inputs = inputs.combine_chunks().to_pylist() + else: + raise ValueError(f"Input type {type(inputs)} not allowed with text model.") + + if not all(isinstance(x, str) for x in inputs): + raise ValueError("Each input should be str.") + + return inputs + + @register("voyageai") class VoyageAIEmbeddingFunction(EmbeddingFunction): """ @@ -74,6 +165,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): ] multimodal_embedding_models: list = ["voyage-multimodal-3"] + def _is_multimodal_model(self, model_name: str): + return ( + model_name in self.multimodal_embedding_models or "multimodal" in model_name + ) + def ndims(self): if self.name == "voyage-3-lite": return 512 @@ -85,55 +181,12 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): "voyage-finance-2", "voyage-multilingual-2", "voyage-law-2", + "voyage-multimodal-3", ]: return 1024 else: raise ValueError(f"Model {self.name} not supported") - def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]: - """ - Sanitize the input to the embedding function. - """ - if isinstance(images, (str, bytes)): - images = [images] - elif isinstance(images, pa.Array): - images = images.to_pylist() - elif isinstance(images, pa.ChunkedArray): - images = images.combine_chunks().to_pylist() - return images - - def generate_text_embeddings(self, text: str, **kwargs) -> np.ndarray: - """ - Get the embeddings for the given texts - - Parameters - ---------- - texts: list[str] or np.ndarray (of str) - The texts to embed - input_type: Optional[str] - - truncation: Optional[bool] - """ - client = VoyageAIEmbeddingFunction._get_client() - if self.name in self.text_embedding_models: - rs = client.embed(texts=[text], model=self.name, **kwargs) - elif self.name in self.multimodal_embedding_models: - rs = client.multimodal_embed(inputs=[[text]], model=self.name, **kwargs) - else: - raise ValueError( - f"Model {self.name} not supported to generate text embeddings" - ) - - return rs.embeddings[0] - - def generate_image_embedding( - self, image: "PIL.Image.Image", **kwargs - ) -> np.ndarray: - rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed( - inputs=[[image]], model=self.name, **kwargs - ) - return rs.embeddings[0] - def compute_query_embeddings( self, query: Union[str, "PIL.Image.Image"], *args, **kwargs ) -> List[np.ndarray]: @@ -144,23 +197,52 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): ---------- query : Union[str, PIL.Image.Image] The query to embed. A query can be either text or an image. + + Returns + ------- + List[np.array]: the list of embeddings """ - if isinstance(query, str): - return [self.generate_text_embeddings(query, input_type="query")] + client = VoyageAIEmbeddingFunction._get_client() + if self._is_multimodal_model(self.name): + result = client.multimodal_embed( + inputs=[[query]], model=self.name, input_type="query", **kwargs + ) else: - PIL = attempt_import_or_raise("PIL", "pillow") - if isinstance(query, PIL.Image.Image): - return [self.generate_image_embedding(query, input_type="query")] - else: - raise TypeError("Only text PIL images supported as query") + result = client.embed( + texts=[query], model=self.name, input_type="query", **kwargs + ) + + return [result.embeddings[0]] def compute_source_embeddings( - self, images: IMAGES, *args, **kwargs + self, inputs: Union[TEXT, IMAGES], *args, **kwargs ) -> List[np.array]: - images = self.sanitize_input(images) - return [ - self.generate_image_embedding(img, input_type="document") for img in images - ] + """ + Compute the embeddings for the inputs + + Parameters + ---------- + inputs : Union[TEXT, IMAGES] + The inputs to embed. The input can be either str, bytes, Path (to an image), + PIL.Image or list of these. + + Returns + ------- + List[np.array]: the list of embeddings + """ + client = VoyageAIEmbeddingFunction._get_client() + if self._is_multimodal_model(self.name): + inputs = sanitize_multimodal_input(inputs) + result = client.multimodal_embed( + inputs=inputs, model=self.name, input_type="document", **kwargs + ) + else: + inputs = sanitize_text_input(inputs) + result = client.embed( + texts=inputs, model=self.name, input_type="document", **kwargs + ) + + return result.embeddings @staticmethod def _get_client(): diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index ba9f8fea..0eec2b3d 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -12,6 +12,7 @@ import pyarrow as pa import pytest from lancedb.embeddings import get_registry from lancedb.pydantic import LanceModel, Vector +import requests # These are integration tests for embedding functions. # They are slow because they require downloading models @@ -516,3 +517,61 @@ def test_voyageai_embedding_function(): tbl.add(df) assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims() + + +@pytest.mark.slow +@pytest.mark.skipif( + os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set" +) +def test_voyageai_multimodal_embedding_function(): + voyageai = ( + get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0) + ) + + class Images(LanceModel): + label: str + image_uri: str = voyageai.SourceField() # image uri as the source + image_bytes: bytes = voyageai.SourceField() # image bytes as the source + vector: Vector(voyageai.ndims()) = voyageai.VectorField() # vector column + vec_from_bytes: Vector(voyageai.ndims()) = ( + voyageai.VectorField() + ) # Another vector column + + db = lancedb.connect("~/lancedb") + table = db.create_table("test", schema=Images, mode="overwrite") + labels = ["cat", "cat", "dog", "dog", "horse", "horse"] + uris = [ + "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg", + "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg", + "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg", + "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg", + "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg", + "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg", + ] + # get each uri as bytes + image_bytes = [requests.get(uri).content for uri in uris] + table.add( + pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes}) + ) + assert len(table.to_pandas()["vector"][0]) == voyageai.ndims() + + +@pytest.mark.slow +@pytest.mark.skipif( + os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set" +) +def test_voyageai_multimodal_embedding_text_function(): + voyageai = ( + get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0) + ) + + class TextModel(LanceModel): + text: str = voyageai.SourceField() + vector: Vector(voyageai.ndims()) = voyageai.VectorField() + + df = pd.DataFrame({"text": ["hello world", "goodbye world"]}) + db = lancedb.connect("~/lancedb") + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(df) + assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()