mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
feat(voyageai): update voyage integration (#2713)
Adding multimodal usage guide VoyageAI integration changes: - Adding voyage-3.5 and voyage-3.5-lite models - Adding voyage-context-3 model - Adding rerank-2.5 and rerank-2.5-lite models
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
# VoyageAI Embeddings : Multimodal
|
||||
|
||||
VoyageAI embeddings can also be used to embed both text and image data. Only some of the models support image data; the supported models are listed
|
||||
under [https://docs.voyageai.com/docs/multimodal-embeddings](https://docs.voyageai.com/docs/multimodal-embeddings)
|
||||
|
||||
Supported parameters (to be passed to the `create` method) are:
|
||||
|
||||
| Parameter | Type | Default Value | Description |
|
||||
|---|---|-------------------------|-------------------------------------------|
|
||||
| `name` | `str` | `"voyage-multimodal-3"` | The model ID of the VoyageAI model to use |
|
||||
|
||||
Usage Example:
|
||||
|
||||
```python
|
||||
import base64
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import get_registry
|
||||
import pandas as pd
|
||||
|
||||
os.environ['VOYAGE_API_KEY'] = 'YOUR_VOYAGE_API_KEY'
|
||||
|
||||
db = lancedb.connect(".lancedb")
|
||||
func = get_registry().get("voyageai").create(name="voyage-multimodal-3")
|
||||
|
||||
|
||||
def image_to_base64(image_bytes: bytes) -> str:
    """Return the base64 text representation of raw image bytes."""
    # b64encode accepts bytes directly, so no intermediate BytesIO copy is needed.
    return base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField() # image uri as the source
|
||||
image_bytes: str = func.SourceField() # image bytes base64 encoded as the source
|
||||
vector: Vector(func.ndims()) = func.VectorField() # vector column
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column
|
||||
|
||||
|
||||
# mode="overwrite" replaces an existing "images" table, so a separate
# existence check followed by drop_table is not needed.
table = db.create_table("images", schema=Images, mode="overwrite")

labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
    "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
    "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]


def fetch_image(uri: str) -> bytes:
    """Download one image and return its raw bytes.

    Raises ``requests.HTTPError`` on a non-success status so that an error
    page is never base64-encoded and embedded as if it were an image.
    """
    response = requests.get(uri, timeout=30)
    response.raise_for_status()
    return response.content


# Download each image and keep it as base64 text for the `image_bytes` column.
images_bytes = [image_to_base64(fetch_image(uri)) for uri in uris]
table.add(
    pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": images_bytes})
)
|
||||
```
|
||||
Now we can search using text from both the default vector column and the custom vector column
|
||||
```python
|
||||
|
||||
# text search against the default `vector` column
actual = (
    table.search("man's best friend", vector_column_name="vector")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(actual.label)  # prints "dog"

# the same text search against the custom `vec_from_bytes` column
frombytes = (
    table.search("man's best friend", vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(frombytes.label)
|
||||
|
||||
```
|
||||
|
||||
Because we're using a multi-modal embedding function, we can also search using images
|
||||
|
||||
```python
|
||||
# image search
from PIL import Image  # needed for Image.open below

query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(BytesIO(image_bytes))

# search against the default `vector` column
actual = (
    table.search(query_image, vector_column_name="vector")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(actual.label == "dog")

# image search using a custom vector column
other = (
    table.search(query_image, vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
# Show the result of the custom-column search, not the previous one.
print(other.label)
|
||||
|
||||
```
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
import base64
|
||||
import os
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any
|
||||
from typing import ClassVar, TYPE_CHECKING, List, Union, Any, Generator
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
@@ -19,6 +19,23 @@ from .utils import api_key_not_found_help, IMAGES, TEXT
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
|
||||
# Token limits for different VoyageAI models
|
||||
VOYAGE_TOTAL_TOKEN_LIMITS = {
|
||||
"voyage-context-3": 32_000,
|
||||
"voyage-3.5-lite": 1_000_000,
|
||||
"voyage-3.5": 320_000,
|
||||
"voyage-3-lite": 120_000,
|
||||
"voyage-3": 120_000,
|
||||
"voyage-multimodal-3": 120_000,
|
||||
"voyage-finance-2": 120_000,
|
||||
"voyage-multilingual-2": 120_000,
|
||||
"voyage-law-2": 120_000,
|
||||
"voyage-code-2": 120_000,
|
||||
}
|
||||
|
||||
# Batch size for embedding requests (max number of items per batch)
|
||||
BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def is_valid_url(text):
|
||||
try:
|
||||
@@ -120,6 +137,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
name: str
|
||||
The name of the model to use. List of acceptable models:
|
||||
|
||||
* voyage-context-3
|
||||
* voyage-3.5
|
||||
* voyage-3.5-lite
|
||||
* voyage-3
|
||||
* voyage-3-lite
|
||||
* voyage-multimodal-3
|
||||
@@ -157,25 +177,35 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
name: str
|
||||
client: ClassVar = None
|
||||
text_embedding_models: list = [
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
"voyage-3",
|
||||
"voyage-3-lite",
|
||||
"voyage-finance-2",
|
||||
"voyage-multilingual-2",
|
||||
"voyage-law-2",
|
||||
"voyage-code-2",
|
||||
]
|
||||
multimodal_embedding_models: list = ["voyage-multimodal-3"]
|
||||
contextual_embedding_models: list = ["voyage-context-3"]
|
||||
|
||||
def _is_multimodal_model(self, model_name: str):
|
||||
return (
|
||||
model_name in self.multimodal_embedding_models or "multimodal" in model_name
|
||||
)
|
||||
|
||||
def _is_contextual_model(self, model_name: str):
|
||||
return model_name in self.contextual_embedding_models or "context" in model_name
|
||||
|
||||
def ndims(self):
|
||||
if self.name == "voyage-3-lite":
|
||||
return 512
|
||||
elif self.name == "voyage-code-2":
|
||||
return 1536
|
||||
elif self.name in [
|
||||
"voyage-context-3",
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
"voyage-3",
|
||||
"voyage-multimodal-3",
|
||||
"voyage-finance-2",
|
||||
@@ -207,6 +237,11 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
result = client.multimodal_embed(
|
||||
inputs=[[query]], model=self.name, input_type="query", **kwargs
|
||||
)
|
||||
elif self._is_contextual_model(self.name):
|
||||
result = client.contextualized_embed(
|
||||
inputs=[[query]], model=self.name, input_type="query", **kwargs
|
||||
)
|
||||
result = result.results[0]
|
||||
else:
|
||||
result = client.embed(
|
||||
texts=[query], model=self.name, input_type="query", **kwargs
|
||||
@@ -231,18 +266,164 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
List[np.array]: the list of embeddings
|
||||
"""
|
||||
client = VoyageAIEmbeddingFunction._get_client()
|
||||
|
||||
# For multimodal models, check if inputs contain images
|
||||
if self._is_multimodal_model(self.name):
|
||||
inputs = sanitize_multimodal_input(inputs)
|
||||
result = client.multimodal_embed(
|
||||
inputs=inputs, model=self.name, input_type="document", **kwargs
|
||||
sanitized = sanitize_multimodal_input(inputs)
|
||||
has_images = any(
|
||||
inp["content"][0].get("type") != "text" for inp in sanitized
|
||||
)
|
||||
if has_images:
|
||||
# Use non-batched API for images
|
||||
result = client.multimodal_embed(
|
||||
inputs=sanitized, model=self.name, input_type="document", **kwargs
|
||||
)
|
||||
return result.embeddings
|
||||
# Extract texts for batching
|
||||
inputs = [inp["content"][0]["text"] for inp in sanitized]
|
||||
else:
|
||||
inputs = sanitize_text_input(inputs)
|
||||
result = client.embed(
|
||||
texts=inputs, model=self.name, input_type="document", **kwargs
|
||||
)
|
||||
|
||||
return result.embeddings
|
||||
# Use batching for all text inputs
|
||||
return self._embed_with_batching(
|
||||
client, inputs, input_type="document", **kwargs
|
||||
)
|
||||
|
||||
def _build_batches(
|
||||
self, client, texts: List[str]
|
||||
) -> Generator[List[str], None, None]:
|
||||
"""
|
||||
Generate batches of texts based on token limits using a generator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : voyageai.Client
|
||||
The VoyageAI client instance.
|
||||
texts : List[str]
|
||||
List of texts to batch.
|
||||
|
||||
Yields
|
||||
------
|
||||
List[str]: Batches of texts.
|
||||
"""
|
||||
if not texts:
|
||||
return
|
||||
|
||||
max_tokens_per_batch = VOYAGE_TOTAL_TOKEN_LIMITS.get(self.name, 120_000)
|
||||
current_batch: List[str] = []
|
||||
current_batch_tokens = 0
|
||||
|
||||
# Tokenize all texts in one API call
|
||||
token_lists = client.tokenize(texts, model=self.name)
|
||||
token_counts = [len(token_list) for token_list in token_lists]
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
n_tokens = token_counts[i]
|
||||
|
||||
# Check if adding this text would exceed limits
|
||||
if current_batch and (
|
||||
len(current_batch) >= BATCH_SIZE
|
||||
or (current_batch_tokens + n_tokens > max_tokens_per_batch)
|
||||
):
|
||||
# Yield the current batch and start a new one
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
current_batch_tokens = 0
|
||||
|
||||
current_batch.append(text)
|
||||
current_batch_tokens += n_tokens
|
||||
|
||||
# Yield the last batch (always has at least one text)
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
|
||||
def _get_embed_function(
|
||||
self, client, input_type: str = "document", **kwargs
|
||||
) -> callable:
|
||||
"""
|
||||
Get the appropriate embedding function based on model type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : voyageai.Client
|
||||
The VoyageAI client instance.
|
||||
input_type : str
|
||||
Either "query" or "document"
|
||||
**kwargs
|
||||
Additional arguments to pass to the embedding API
|
||||
|
||||
Returns
|
||||
-------
|
||||
callable: A function that takes a batch of texts and returns embeddings.
|
||||
"""
|
||||
if self._is_multimodal_model(self.name):
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
batch_inputs = sanitize_multimodal_input(batch)
|
||||
result = client.multimodal_embed(
|
||||
inputs=batch_inputs,
|
||||
model=self.name,
|
||||
input_type=input_type,
|
||||
**kwargs,
|
||||
)
|
||||
return result.embeddings
|
||||
|
||||
return embed_batch
|
||||
|
||||
elif self._is_contextual_model(self.name):
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
result = client.contextualized_embed(
|
||||
inputs=[batch], model=self.name, input_type=input_type, **kwargs
|
||||
)
|
||||
return result.results[0].embeddings
|
||||
|
||||
return embed_batch
|
||||
|
||||
else:
|
||||
|
||||
def embed_batch(batch: List[str]) -> List[np.array]:
|
||||
result = client.embed(
|
||||
texts=batch, model=self.name, input_type=input_type, **kwargs
|
||||
)
|
||||
return result.embeddings
|
||||
|
||||
return embed_batch
|
||||
|
||||
def _embed_with_batching(
|
||||
self, client, texts: List[str], input_type: str = "document", **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Embed texts with automatic batching based on token limits.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : voyageai.Client
|
||||
The VoyageAI client instance.
|
||||
texts : List[str]
|
||||
List of texts to embed.
|
||||
input_type : str
|
||||
Either "query" or "document"
|
||||
**kwargs
|
||||
Additional arguments to pass to the embedding API
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[np.array]: List of embeddings.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Get the appropriate embedding function for this model type
|
||||
embed_fn = self._get_embed_function(client, input_type=input_type, **kwargs)
|
||||
|
||||
# Process each batch
|
||||
all_embeddings = []
|
||||
for batch in self._build_batches(client, texts):
|
||||
batch_embeddings = embed_fn(batch)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
@staticmethod
|
||||
def _get_client():
|
||||
|
||||
@@ -21,6 +21,8 @@ class VoyageAIReranker(Reranker):
|
||||
----------
|
||||
model_name : str, default "rerank-english-v2.0"
|
||||
The name of the cross encoder model to use. Available voyageai models are:
|
||||
- rerank-2.5
|
||||
- rerank-2.5-lite
|
||||
- rerank-2
|
||||
- rerank-2-lite
|
||||
column : str, default "text"
|
||||
|
||||
@@ -532,6 +532,27 @@ def test_voyageai_embedding_function():
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function_contextual_model(tmp_path):
    """Embed two texts with voyage-context-3 and check each vector's width."""
    embedder = (
        get_registry().get("voyageai").create(name="voyage-context-3", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    # Use pytest's per-test directory so the run neither depends on nor
    # pollutes a shared database under the user's home directory.
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    vectors = tbl.to_pandas()["vector"]
    # Every input row must come back with a vector of the advertised width.
    assert len(vectors) == len(df)
    assert all(len(vec) == embedder.ndims() for vec in vectors)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
|
||||
@@ -484,7 +484,7 @@ def test_jina_reranker(tmp_path, use_tantivy):
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_voyageai_reranker(tmp_path, use_tantivy):
    """Run the shared reranker checks against VoyageAI's rerank-2.5 model."""
    pytest.importorskip("voyageai")
    # Only one reranker is built: the earlier rerank-2 instance was
    # overwritten before use, so constructing it was dead code.
    reranker = VoyageAIReranker(model_name="rerank-2.5")
    table, schema = get_test_table(tmp_path, use_tantivy)
    _run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user