This PR adds ColPali support via a `ColPaliEmbeddings` class (registered as `"colpali"`), using ColQwen2.5 to produce multi-vector text/image embeddings. It also adds a `MultiVector` Pydantic type to handle the resulting lists of vectors. I've added integration tests for the embedding model and unit tests for the new Pydantic type. This could also serve as a template for other ColPali variants, at least until transformers🤗 supports them natively.

Still `TODO`:

- [ ] Documentation
- [ ] Add an example

_Could also allow an image as the query, but it didn't work well when testing it._

[ColPali-Engine](https://github.com/illuin-tech/colpali) version: 0.3.9.dev17+g3faee24

## Summary by CodeRabbit

- **New Features**
  - Introduced support for ColPali-based multimodal multi-vector embeddings for both text and images.
  - Added a new embedding class for generating multi-vector embeddings, configurable for various model and processing options.
  - Added a new Pydantic type for multi-vector embeddings, supporting validation and schema generation for lists of fixed-dimension vectors.
- **Bug Fixes**
  - Ensured proper asynchronous index creation in query tests for improved reliability.
- **Tests**
  - Added integration tests for ColPali embeddings, including text-to-image search and validation of multi-vector fields.
  - Added comprehensive tests for the new multi-vector Pydantic type, covering schema, validation, and default value behavior.
- **Chores**
  - Updated optional dependencies to include the ColPali engine.
  - Added a utility to check for availability of flash attention support.
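For reviewers, a minimal sketch of how the new `MultiVector` type is meant to be used in a schema (the model and field names here are illustrative; the real usage lives in the tests below):

```python
from lancedb.pydantic import LanceModel, MultiVector

class Page(LanceModel):
    # A multi-vector field: a variable-length list of vectors that
    # must all share the same fixed dimension (here 128).
    embeddings: MultiVector(128)
```

Where `Vector(dim)` maps to a single fixed-size list column, `MultiVector(dim)` validates a list of such vectors, which is what ColPali-style models emit (one vector per query token or image patch).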
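And an end-to-end sketch of the `"colpali"` registry entry, condensed from `test_colpali` in this file (the database path, image URI, and query are illustrative):

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, MultiVector

# Defaults to a ColQwen2.5 checkpoint via colpali-engine.
func = get_registry().get("colpali").create()

class Docs(LanceModel):
    image_uri: str = func.SourceField()  # images are the embedding source
    image_vectors: MultiVector(func.ndims()) = func.VectorField()

db = lancedb.connect("/tmp/colpali_demo")
table = db.create_table("docs", schema=Docs)
table.add([{"image_uri": "http://example.com/page.jpg"}])

# Text queries are embedded as multiple vectors and matched against the
# stored multi-vector field (ColPali-style late interaction).
hits = (
    table.search("a fluffy companion", vector_column_name="image_vectors")
    .limit(3)
    .to_list()
)
```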
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import importlib
import io
import os

import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
import requests

from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector, MultiVector

# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connecting to external APIs.

# Probe for optional dependencies without importing them at module scope.
try:
    if importlib.util.find_spec("mlx.core") is not None:
        _mlx = True
    else:
        _mlx = None
except Exception:
    _mlx = None

try:
    if importlib.util.find_spec("imagebind") is not None:
        _imagebind = True
    else:
        _imagebind = None
except Exception:
    _imagebind = None


@pytest.mark.slow
@pytest.mark.parametrize(
    "alias", ["sentence-transformers", "openai", "huggingface", "ollama"]
)
def test_basic_text_embeddings(alias, tmp_path):
    db = lancedb.connect(tmp_path)
    registry = get_registry()
    func = registry.get(alias).create(max_retries=0)
    func2 = registry.get(alias).create(max_retries=0)

    class Words(LanceModel):
        text: str = func.SourceField()
        text2: str = func2.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()
        vector2: Vector(func2.ndims()) = func2.VectorField()

    table = db.create_table("words", schema=Words)
    table.add(
        pd.DataFrame(
            {
                "text": [
                    "hello world",
                    "goodbye world",
                    "fizz",
                    "buzz",
                    "foo",
                    "bar",
                    "baz",
                ],
                "text2": [
                    "to be or not to be",
                    "that is the question",
                    "for whether tis nobler",
                    "in the mind to suffer",
                    "the slings and arrows",
                    "of outrageous fortune",
                    "or to take arms",
                ],
            }
        )
    )

    query = "greeting"
    actual = (
        table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
    )

    # Searching with a precomputed query vector should match text search.
    vec = func.compute_query_embeddings(query)[0]
    expected = (
        table.search(vec, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
    )
    assert actual.text == expected.text
    assert actual.text == "hello world"
    assert not np.allclose(actual.vector, actual.vector2)

    actual = (
        table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
    )
    assert actual.text != "hello world"
    assert not np.allclose(actual.vector, actual.vector2)


@pytest.mark.slow
def test_openclip(tmp_path):
    from PIL import Image

    db = lancedb.connect(tmp_path)
    registry = get_registry()
    func = registry.get("open-clip").create(max_retries=0)

    class Images(LanceModel):
        label: str
        image_uri: str = func.SourceField()
        image_bytes: bytes = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()
        vec_from_bytes: Vector(func.ndims()) = func.VectorField()

    table = db.create_table("images", schema=Images)
    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    # get each uri as bytes
    image_bytes = [requests.get(uri).content for uri in uris]
    table.add(
        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
    )

    # text search
    actual = (
        table.search("man's best friend", vector_column_name="vector")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == "dog"
    frombytes = (
        table.search("man's best friend", vector_column_name="vec_from_bytes")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == frombytes.label
    assert np.allclose(actual.vector, frombytes.vector)

    # image search
    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
    image_bytes = requests.get(query_image_uri).content
    query_image = Image.open(io.BytesIO(image_bytes))
    actual = (
        table.search(query_image, vector_column_name="vector")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == "dog"
    other = (
        table.search(query_image, vector_column_name="vec_from_bytes")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == other.label

    # Embeddings computed from URIs and from raw bytes should match exactly.
    arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
    assert np.allclose(
        arrow_table["vector"].combine_chunks().values.to_numpy(),
        arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
    )


@pytest.mark.skipif(
    _imagebind is None,
    reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
    import shutil
    import tempfile

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Created temporary directory {temp_dir}")

        def download_images(image_uris):
            downloaded_image_paths = []
            for uri in image_uris:
                try:
                    response = requests.get(uri, stream=True)
                    if response.status_code == 200:
                        # Extract image name from URI
                        image_name = os.path.basename(uri)
                        image_path = os.path.join(temp_dir, image_name)
                        with open(image_path, "wb") as out_file:
                            shutil.copyfileobj(response.raw, out_file)
                        downloaded_image_paths.append(image_path)
                except Exception as e:  # noqa: PERF203
                    print(f"Failed to download {uri}. Error: {e}")
            return temp_dir, downloaded_image_paths

        db = lancedb.connect(tmp_path)
        registry = get_registry()
        func = registry.get("imagebind").create(max_retries=0)

        class Images(LanceModel):
            label: str
            image_uri: str = func.SourceField()
            vector: Vector(func.ndims()) = func.VectorField()

        table = db.create_table("images", schema=Images)
        labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
        uris = [
            "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
            "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
            "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
            "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
            "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
            "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
        ]
        temp_dir, downloaded_images = download_images(uris)
        table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))

        # text search
        actual = (
            table.search("man's best friend", vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"

        # image search
        query_image_uri = [
            "https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
        ]
        temp_dir, downloaded_images = download_images(query_image_uri)
        query_image_uri = downloaded_images[0]
        actual = (
            table.search(query_image_uri, vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"
        # The TemporaryDirectory context manager removes temp_dir on exit,
        # so no manual cleanup is needed here.


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)  # TODO: also skip if the cohere package is not installed
def test_cohere_embedding_function():
    cohere = (
        get_registry()
        .get("cohere")
        .create(name="embed-multilingual-v2.0", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = cohere.SourceField()
        vector: Vector(cohere.ndims()) = cohere.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()


@pytest.mark.slow
def test_instructor_embedding(tmp_path):
    model = get_registry().get("instructor").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
)
def test_gemini_embedding(tmp_path):
    model = get_registry().get("gemini-text").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


@pytest.mark.skipif(
    _mlx is None,
    reason="mlx tests are only relevant for Apple users.",
)
@pytest.mark.slow
def test_gte_embedding(tmp_path):
    model = get_registry().get("gte-text").create()

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


def aws_setup():
    # Returns True only if boto3 is importable and AWS credentials resolve.
    try:
        import boto3

        sts = boto3.client("sts")
        sts.get_caller_identity()
        return True
    except Exception:
        return False


@pytest.mark.slow
@pytest.mark.skipif(
    not aws_setup(), reason="AWS credentials not set or libraries not installed"
)
def test_bedrock_embedding(tmp_path):
    for name in [
        "amazon.titan-embed-text-v1",
        "cohere.embed-english-v3",
        "cohere.embed-multilingual-v3",
    ]:
        model = get_registry().get("bedrock-text").create(max_retries=0, name=name)

        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
        db = lancedb.connect(tmp_path)
        tbl = db.create_table("test", schema=TextModel, mode="overwrite")

        tbl.add(df)
        assert len(tbl.to_pandas()["vector"][0]) == model.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
    def _get_table(model):
        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        db = lancedb.connect(tmp_path)
        tbl = db.create_table("test", schema=TextModel, mode="overwrite")

        return tbl

    # Default model
    model = get_registry().get("openai").create(max_retries=0)
    tbl = _get_table(model)
    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

    # Explicit model name
    model = (
        get_registry()
        .get("openai")
        .create(max_retries=0, name="text-embedding-3-large")
    )
    tbl = _get_table(model)

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

    # Explicit model name with a reduced output dimension
    model = (
        get_registry()
        .get("openai")
        .create(max_retries=0, name="text-embedding-3-large", dim=1024)
    )
    tbl = _get_table(model)

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("WATSONX_API_KEY") is None
    or os.environ.get("WATSONX_PROJECT_ID") is None,
    reason="WATSONX_API_KEY and WATSONX_PROJECT_ID not set",
)
def test_watsonx_embedding(tmp_path):
    from lancedb.embeddings import WatsonxEmbeddings

    for name in WatsonxEmbeddings.model_names():
        model = get_registry().get("watsonx").create(max_retries=0, name=name)

        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        db = lancedb.connect(tmp_path)
        tbl = db.create_table("watsonx_test", schema=TextModel, mode="overwrite")
        df = pd.DataFrame({"text": ["hello world", "goodbye world"]})

        tbl.add(df)
        assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
        assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_with_empty_strs(tmp_path):
    model = get_registry().get("openai").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", ""]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df, on_bad_vectors="skip")
    tb = tbl.to_arrow()
    assert tb.schema.field("vector").type == pa.list_(
        pa.float32(), model.ndims()
    )
    assert len(tb) == 2
    # The empty string should produce a null vector rather than a bad embedding.
    assert tb["vector"].is_null().to_pylist() == [False, True]


@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("ollama") is None, reason="Ollama not installed"
)
def test_ollama_embedding(tmp_path):
    model = get_registry().get("ollama").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)

    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()

    result = tbl.search("hello").limit(1).to_pandas()
    assert result["text"][0] == "hello world"

    # Test safe_model_dump
    dumped_model = model.safe_model_dump()
    assert isinstance(dumped_model, dict)
    assert "name" in dumped_model
    assert "max_retries" in dumped_model
    assert dumped_model["max_retries"] == 0
    assert all(not k.startswith("_") for k in dumped_model.keys())

    # Test serialization of the dumped model
    import json

    try:
        json.dumps(dumped_model)
    except TypeError:
        pytest.fail("Failed to JSON serialize the dumped model")


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function():
    voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_function():
    voyageai = (
        get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
    )

    class Images(LanceModel):
        label: str
        image_uri: str = voyageai.SourceField()  # image uri as the source
        image_bytes: bytes = voyageai.SourceField()  # image bytes as the source
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()  # vector column
        vec_from_bytes: Vector(voyageai.ndims()) = (
            voyageai.VectorField()
        )  # another vector column

    db = lancedb.connect("~/lancedb")
    table = db.create_table("test", schema=Images, mode="overwrite")
    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    # get each uri as bytes
    image_bytes = [requests.get(uri).content for uri in uris]
    table.add(
        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
    )
    assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_text_function():
    voyageai = (
        get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("colpali_engine") is None,
    reason="colpali_engine not installed",
)
def test_colpali(tmp_path):
    db = lancedb.connect(tmp_path)
    registry = get_registry()
    func = registry.get("colpali").create()

    class MediaItems(LanceModel):
        text: str
        image_uri: str = func.SourceField()
        image_bytes: bytes = func.SourceField()
        image_vectors: MultiVector(func.ndims()) = (
            func.VectorField()
        )  # multi-vector image embeddings

    table = db.create_table("media", schema=MediaItems)

    texts = [
        "a cute cat playing with yarn",
        "a puppy in a flower field",
        "a red sports car on the highway",
        "a vintage bicycle leaning against a wall",
        "a plate of delicious pasta",
        "fresh fruit salad in a bowl",
    ]

    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]

    # Get images as bytes
    image_bytes = [requests.get(uri).content for uri in uris]

    table.add(
        pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
    )

    # Test text-to-image search
    image_results = (
        table.search("fluffy companion", vector_column_name="image_vectors")
        .limit(1)
        .to_pydantic(MediaItems)[0]
    )
    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()

    # Verify multi-vector dimensions: each image yields several vectors,
    # each with the model's embedding dimension.
    first_row = table.to_arrow().to_pylist()[0]
    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
    assert len(first_row["image_vectors"][0]) == func.ndims(), (
        "Vector dimension mismatch"
    )