mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-12 23:02:59 +00:00
feat: voyage-multimodal-3.5 (#2887)
voyage-multimodal-3.5 support (text, image and video embeddings)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import base64
|
||||
import os
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator, Optional
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
@@ -45,11 +45,29 @@ def is_valid_url(text):
|
||||
return False
|
||||
|
||||
|
||||
VIDEO_EXTENSIONS = {".mp4", ".webm", ".mov", ".avi", ".mkv", ".m4v", ".gif"}
|
||||
|
||||
|
||||
def is_video_url(url: str) -> bool:
|
||||
"""Check if URL points to a video file based on extension."""
|
||||
parsed = urlparse(url)
|
||||
path = parsed.path.lower()
|
||||
return any(path.endswith(ext) for ext in VIDEO_EXTENSIONS)
|
||||
|
||||
|
||||
def is_video_path(path: Path) -> bool:
|
||||
"""Check if file path is a video file based on extension."""
|
||||
return path.suffix.lower() in VIDEO_EXTENSIONS
|
||||
|
||||
|
||||
def transform_input(input_data: Union[str, bytes, Path]):
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(input_data, str):
|
||||
if is_valid_url(input_data):
|
||||
content = {"type": "image_url", "image_url": input_data}
|
||||
if is_video_url(input_data):
|
||||
content = {"type": "video_url", "video_url": input_data}
|
||||
else:
|
||||
content = {"type": "image_url", "image_url": input_data}
|
||||
else:
|
||||
content = {"type": "text", "text": input_data}
|
||||
elif isinstance(input_data, PIL.Image.Image):
|
||||
@@ -70,14 +88,24 @@ def transform_input(input_data: Union[str, bytes, Path]):
|
||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||
}
|
||||
elif isinstance(input_data, Path):
|
||||
img = PIL.Image.open(input_data)
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
content = {
|
||||
"type": "image_base64",
|
||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||
}
|
||||
if is_video_path(input_data):
|
||||
# Read video file and encode as base64
|
||||
with open(input_data, "rb") as f:
|
||||
video_bytes = f.read()
|
||||
video_str = base64.b64encode(video_bytes).decode("utf-8")
|
||||
content = {
|
||||
"type": "video_base64",
|
||||
"video_base64": video_str,
|
||||
}
|
||||
else:
|
||||
img = PIL.Image.open(input_data)
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
content = {
|
||||
"type": "image_base64",
|
||||
"image_base64": "data:image/jpeg;base64," + img_str,
|
||||
}
|
||||
else:
|
||||
raise ValueError("Each input should be either str, bytes, Path or Image.")
|
||||
|
||||
@@ -91,6 +119,8 @@ def sanitize_multimodal_input(inputs: Union[TEXT, IMAGES]) -> List[Any]:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(inputs, (str, bytes, Path, PIL.Image.Image)):
|
||||
inputs = [inputs]
|
||||
elif isinstance(inputs, list):
|
||||
pass # Already a list, use as-is
|
||||
elif isinstance(inputs, pa.Array):
|
||||
inputs = inputs.to_pylist()
|
||||
elif isinstance(inputs, pa.ChunkedArray):
|
||||
@@ -143,11 +173,16 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
* voyage-3
|
||||
* voyage-3-lite
|
||||
* voyage-multimodal-3
|
||||
* voyage-multimodal-3.5
|
||||
* voyage-finance-2
|
||||
* voyage-multilingual-2
|
||||
* voyage-law-2
|
||||
* voyage-code-2
|
||||
|
||||
output_dimension: int, optional
|
||||
The output dimension for models that support flexible dimensions.
|
||||
Currently only voyage-multimodal-3.5 supports this feature.
|
||||
Valid options: 256, 512, 1024 (default), 2048.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -175,7 +210,10 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
|
||||
name: str
|
||||
output_dimension: Optional[int] = None
|
||||
client: ClassVar = None
|
||||
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
|
||||
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
|
||||
text_embedding_models: list = [
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
@@ -186,7 +224,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"voyage-law-2",
|
||||
"voyage-code-2",
|
||||
]
|
||||
multimodal_embedding_models: list = ["voyage-multimodal-3"]
|
||||
multimodal_embedding_models: list = ["voyage-multimodal-3", "voyage-multimodal-3.5"]
|
||||
contextual_embedding_models: list = ["voyage-context-3"]
|
||||
|
||||
def _is_multimodal_model(self, model_name: str):
|
||||
@@ -198,6 +236,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
return model_name in self.contextual_embedding_models or "context" in model_name
|
||||
|
||||
def ndims(self):
|
||||
# Handle flexible dimension models
|
||||
if self.name in self._FLEXIBLE_DIM_MODELS:
|
||||
if self.output_dimension is not None:
|
||||
if self.output_dimension not in self._VALID_DIMENSIONS:
|
||||
raise ValueError(
|
||||
f"Invalid output_dimension {self.output_dimension} "
|
||||
f"for {self.name}. Valid options: {self._VALID_DIMENSIONS}"
|
||||
)
|
||||
return self.output_dimension
|
||||
return 1024 # default dimension
|
||||
|
||||
if self.name == "voyage-3-lite":
|
||||
return 512
|
||||
elif self.name == "voyage-code-2":
|
||||
@@ -211,12 +260,17 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"voyage-finance-2",
|
||||
"voyage-multilingual-2",
|
||||
"voyage-law-2",
|
||||
"voyage-multimodal-3",
|
||||
]:
|
||||
return 1024
|
||||
else:
|
||||
raise ValueError(f"Model {self.name} not supported")
|
||||
|
||||
def _get_multimodal_kwargs(self, **kwargs):
|
||||
"""Get kwargs for multimodal embed call, including output_dimension if set."""
|
||||
if self.name in self._FLEXIBLE_DIM_MODELS and self.output_dimension is not None:
|
||||
kwargs["output_dimension"] = self.output_dimension
|
||||
return kwargs
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
@@ -234,6 +288,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
client = VoyageAIEmbeddingFunction._get_client()
|
||||
if self._is_multimodal_model(self.name):
|
||||
kwargs = self._get_multimodal_kwargs(**kwargs)
|
||||
result = client.multimodal_embed(
|
||||
inputs=[[query]], model=self.name, input_type="query", **kwargs
|
||||
)
|
||||
@@ -275,6 +330,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
)
|
||||
if has_images:
|
||||
# Use non-batched API for images
|
||||
kwargs = self._get_multimodal_kwargs(**kwargs)
|
||||
result = client.multimodal_embed(
|
||||
inputs=sanitized, model=self.name, input_type="document", **kwargs
|
||||
)
|
||||
@@ -357,6 +413,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
callable: A function that takes a batch of texts and returns embeddings.
|
||||
"""
|
||||
if self._is_multimodal_model(self.name):
|
||||
multimodal_kwargs = self._get_multimodal_kwargs(**kwargs)
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
batch_inputs = sanitize_multimodal_input(batch)
|
||||
@@ -364,7 +421,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
inputs=batch_inputs,
|
||||
model=self.name,
|
||||
input_type=input_type,
|
||||
**kwargs,
|
||||
**multimodal_kwargs,
|
||||
)
|
||||
return result.embeddings
|
||||
|
||||
|
||||
@@ -613,6 +613,133 @@ def test_voyageai_multimodal_embedding_text_function():
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_embedding_function():
|
||||
"""Test voyage-multimodal-3.5 model with text input."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", max_retries=0)
|
||||
)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table("test_multimodal_35", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
assert voyageai.ndims() == 1024
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_flexible_dimensions():
|
||||
"""Test voyage-multimodal-3.5 model with custom output dimension."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", output_dimension=512, max_retries=0)
|
||||
)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
assert voyageai.ndims() == 512
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table("test_multimodal_35_dim", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == 512
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_image_embedding():
|
||||
"""Test voyage-multimodal-3.5 model with image input."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", max_retries=0)
|
||||
)
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
db = lancedb.connect("~/lancedb")
|
||||
table = db.create_table(
|
||||
"test_multimodal_35_images", schema=Images, mode="overwrite"
|
||||
)
|
||||
labels = ["cat", "dog"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
]
|
||||
table.add(pd.DataFrame({"label": labels, "image_uri": uris}))
|
||||
assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
assert voyageai.ndims() == 1024
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
@pytest.mark.parametrize("dimension", [256, 512, 1024, 2048])
|
||||
def test_voyageai_multimodal_35_all_dimensions(dimension):
|
||||
"""Test voyage-multimodal-3.5 model with all valid output dimensions."""
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", output_dimension=dimension, max_retries=0)
|
||||
)
|
||||
|
||||
assert voyageai.ndims() == dimension
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table(
|
||||
f"test_multimodal_35_dim_{dimension}", schema=TextModel, mode="overwrite"
|
||||
)
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == dimension
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_35_invalid_dimension():
|
||||
"""Test voyage-multimodal-3.5 model raises error for invalid output dimension."""
|
||||
with pytest.raises(ValueError, match="Invalid output_dimension"):
|
||||
voyageai = (
|
||||
get_registry()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-multimodal-3.5", output_dimension=999, max_retries=0)
|
||||
)
|
||||
# ndims() is where the validation happens
|
||||
voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("colpali_engine") is None,
|
||||
|
||||
Reference in New Issue
Block a user