fix: voyageai regression, multimodal supersedes text models (#2268)
fix #2160
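
A minimal sketch of the behaviour the new tests below exercise (it assumes a valid VOYAGE_API_KEY and uses the registry API shown in this diff; the table name and sample rows are illustrative): after this fix, a voyage-multimodal-3 embedding function can back a plain-text source column, while image sources continue to work.

    import lancedb
    import pandas as pd
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector

    # Multimodal Voyage model driving a text-only source column
    func = get_registry().get("voyageai").create(name="voyage-multimodal-3")

    class TextModel(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    db = lancedb.connect("~/lancedb")  # illustrative local path
    tbl = db.create_table("voyage_demo", schema=TextModel, mode="overwrite")
    tbl.add(pd.DataFrame({"text": ["hello world", "goodbye world"]}))
    assert len(tbl.to_pandas()["vector"][0]) == func.ndims()  # 1024 for voyage-multimodal-3
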
@@ -56,6 +56,7 @@ tests = [
    "tantivy",
    "pyarrow-stubs",
    "pylance>=0.23.2",
    "requests",
]
dev = [
    "ruff",
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import base64
import os
from typing import ClassVar, TYPE_CHECKING, List, Union
from typing import ClassVar, TYPE_CHECKING, List, Union, Any

from pathlib import Path
from urllib.parse import urlparse
from io import BytesIO

import numpy as np
import pyarrow as pa
@@ -11,12 +14,100 @@ import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, IMAGES
from .utils import api_key_not_found_help, IMAGES, TEXT

if TYPE_CHECKING:
    import PIL


def is_valid_url(text):
    try:
        parsed = urlparse(text)
        return bool(parsed.scheme) and bool(parsed.netloc)
    except Exception:
        return False


def transform_input(input_data: Union[str, bytes, Path]):
    PIL = attempt_import_or_raise("PIL", "pillow")
    if isinstance(input_data, str):
        if is_valid_url(input_data):
            content = {"type": "image_url", "image_url": input_data}
        else:
            content = {"type": "text", "text": input_data}
    elif isinstance(input_data, PIL.Image.Image):
        buffered = BytesIO()
        input_data.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        content = {
            "type": "image_base64",
            "image_base64": "data:image/jpeg;base64," + img_str,
        }
    elif isinstance(input_data, bytes):
        img = PIL.Image.open(BytesIO(input_data))
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        content = {
            "type": "image_base64",
            "image_base64": "data:image/jpeg;base64," + img_str,
        }
    elif isinstance(input_data, Path):
        img = PIL.Image.open(input_data)
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        content = {
            "type": "image_base64",
            "image_base64": "data:image/jpeg;base64," + img_str,
        }
    else:
        raise ValueError("Each input should be either str, bytes, Path or Image.")

    return {"content": [content]}


def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
    """
    Sanitize the input to the embedding function.
    """
    PIL = attempt_import_or_raise("PIL", "pillow")
    if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
        inputs = [inputs]
    elif isinstance(inputs, pa.Array):
        inputs = inputs.to_pylist()
    elif isinstance(inputs, pa.ChunkedArray):
        inputs = inputs.combine_chunks().to_pylist()
    else:
        raise ValueError(
            f"Input type {type(inputs)} not allowed with multimodal model."
        )

    if not all(isinstance(x, (str, bytes, Path, PIL.Image.Image)) for x in inputs):
        raise ValueError("Each input should be either str, bytes, Path or Image.")

    return [transform_input(i) for i in inputs]


def sanitize_text_input(inputs: TEXT) -> List[str]:
    """
    Sanitize the input to the embedding function.
    """
    if isinstance(inputs, str):
        inputs = [inputs]
    elif isinstance(inputs, pa.Array):
        inputs = inputs.to_pylist()
    elif isinstance(inputs, pa.ChunkedArray):
        inputs = inputs.combine_chunks().to_pylist()
    else:
        raise ValueError(f"Input type {type(inputs)} not allowed with text model.")

    if not all(isinstance(x, str) for x in inputs):
        raise ValueError("Each input should be str.")

    return inputs


@register("voyageai")
class VoyageAIEmbeddingFunction(EmbeddingFunction):
    """
@@ -74,6 +165,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
    ]
    multimodal_embedding_models: list = ["voyage-multimodal-3"]

    def _is_multimodal_model(self, model_name: str):
        return (
            model_name in self.multimodal_embedding_models or "multimodal" in model_name
        )

    def ndims(self):
        if self.name == "voyage-3-lite":
            return 512
@@ -85,55 +181,12 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
            "voyage-finance-2",
            "voyage-multilingual-2",
            "voyage-law-2",
            "voyage-multimodal-3",
        ]:
            return 1024
        else:
            raise ValueError(f"Model {self.name} not supported")

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def generate_text_embeddings(self, text: str, **kwargs) -> np.ndarray:
        """
        Get the embeddings for the given text

        Parameters
        ----------
        text: str
            The text to embed
        input_type: Optional[str]

        truncation: Optional[bool]
        """
        client = VoyageAIEmbeddingFunction._get_client()
        if self.name in self.text_embedding_models:
            rs = client.embed(texts=[text], model=self.name, **kwargs)
        elif self.name in self.multimodal_embedding_models:
            rs = client.multimodal_embed(inputs=[[text]], model=self.name, **kwargs)
        else:
            raise ValueError(
                f"Model {self.name} not supported to generate text embeddings"
            )

        return rs.embeddings[0]

    def generate_image_embedding(
        self, image: "PIL.Image.Image", **kwargs
    ) -> np.ndarray:
        rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed(
            inputs=[[image]], model=self.name, **kwargs
        )
        return rs.embeddings[0]

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
@@ -144,23 +197,52 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.

        Returns
        -------
        List[np.array]: the list of embeddings
        """
        if isinstance(query, str):
            return [self.generate_text_embeddings(query, input_type="query")]
        client = VoyageAIEmbeddingFunction._get_client()
        if self._is_multimodal_model(self.name):
            result = client.multimodal_embed(
                inputs=[[query]], model=self.name, input_type="query", **kwargs
            )
        else:
            PIL = attempt_import_or_raise("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query, input_type="query")]
            else:
                raise TypeError("Only text or PIL images supported as query")
            result = client.embed(
                texts=[query], model=self.name, input_type="query", **kwargs
            )

        return [result.embeddings[0]]

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
        self, inputs: Union[TEXT, IMAGES], *args, **kwargs
    ) -> List[np.array]:
        images = self.sanitize_input(images)
        return [
            self.generate_image_embedding(img, input_type="document") for img in images
        ]
        """
        Compute the embeddings for the inputs

        Parameters
        ----------
        inputs : Union[TEXT, IMAGES]
            The inputs to embed. The input can be either str, bytes, Path (to an image),
            PIL.Image or list of these.

        Returns
        -------
        List[np.array]: the list of embeddings
        """
        client = VoyageAIEmbeddingFunction._get_client()
        if self._is_multimodal_model(self.name):
            inputs = sanitize_multimodal_input(inputs)
            result = client.multimodal_embed(
                inputs=inputs, model=self.name, input_type="document", **kwargs
            )
        else:
            inputs = sanitize_text_input(inputs)
            result = client.embed(
                texts=inputs, model=self.name, input_type="document", **kwargs
            )

        return result.embeddings

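    # Routing summary for source embeddings (illustrative; model names taken from the
    # lists above): multimodal models such as "voyage-multimodal-3" go through
    # sanitize_multimodal_input() + client.multimodal_embed() and accept str, bytes,
    # Path and PIL.Image sources, while text-only models such as "voyage-3-lite" go
    # through sanitize_text_input() + client.embed() and accept str sources only.
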
    @staticmethod
    def _get_client():

@@ -12,6 +12,7 @@ import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
import requests

# These are integration tests for embedding functions.
# They are slow because they require downloading models
@@ -516,3 +517,61 @@ def test_voyageai_embedding_function():

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_function():
    voyageai = (
        get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
    )

    class Images(LanceModel):
        label: str
        image_uri: str = voyageai.SourceField()  # image uri as the source
        image_bytes: bytes = voyageai.SourceField()  # image bytes as the source
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()  # vector column
        vec_from_bytes: Vector(voyageai.ndims()) = (
            voyageai.VectorField()
        )  # Another vector column

    db = lancedb.connect("~/lancedb")
    table = db.create_table("test", schema=Images, mode="overwrite")
    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    # get each uri as bytes
    image_bytes = [requests.get(uri).content for uri in uris]
    table.add(
        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
    )
    assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_text_function():
    voyageai = (
        get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()