Files
lancedb/python/python/tests/test_embeddings_slow.py
fzowl 93c2cf2f59 feat(voyageai): update voyage integration (#2713)
Adding multimodal usage guide
VoyageAI integration changes:
 - Adding voyage-3.5 and voyage-3.5-lite models
 - Adding voyage-context-3 model
 - Adding rerank-2.5 and rerank-2.5-lite models
2025-10-29 16:49:07 +05:30

846 lines
27 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import importlib
import io
import os
import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector, MultiVector
# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connecting to an external API.
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
else:
_mlx = None
except Exception:
_mlx = None
try:
if importlib.util.find_spec("imagebind") is not None:
_imagebind = True
else:
_imagebind = None
except Exception:
_imagebind = None
@pytest.mark.slow
@pytest.mark.parametrize(
    "alias", ["sentence-transformers", "openai", "huggingface", "ollama"]
)
def test_basic_text_embeddings(alias, tmp_path):
    """Embed two text columns with two instances of one embedding function.

    Checks that searching by raw text and by a precomputed query embedding
    return the same row, and that each vector column ranks against its own
    source column.
    """
    primary = get_registry().get(alias).create(max_retries=0)
    secondary = get_registry().get(alias).create(max_retries=0)

    class Words(LanceModel):
        text: str = primary.SourceField()
        text2: str = secondary.SourceField()
        vector: Vector(primary.ndims()) = primary.VectorField()
        vector2: Vector(secondary.ndims()) = secondary.VectorField()

    first_column = [
        "hello world",
        "goodbye world",
        "fizz",
        "buzz",
        "foo",
        "bar",
        "baz",
    ]
    second_column = [
        "to be or not to be",
        "that is the question",
        "for whether tis nobler",
        "in the mind to suffer",
        "the slings and arrows",
        "of outrageous fortune",
        "or to take arms",
    ]
    table = lancedb.connect(tmp_path).create_table("words", schema=Words)
    table.add(pd.DataFrame({"text": first_column, "text2": second_column}))

    def nearest(needle, column):
        # Top-1 hit for `needle` against the given vector column.
        hits = table.search(needle, vector_column_name=column).limit(1)
        return hits.to_pydantic(Words)[0]

    query = "greeting"
    by_text = nearest(query, "vector")
    query_vector = primary.compute_query_embeddings(query)[0]
    by_vector = nearest(query_vector, "vector")

    assert by_text.text == by_vector.text
    assert by_text.text == "hello world"
    assert not np.allclose(by_text.vector, by_text.vector2)

    # The same query against the second column must land on a different row.
    by_second = nearest(query, "vector2")
    assert by_second.text != "hello world"
    assert not np.allclose(by_second.vector, by_second.vector2)
@pytest.fixture(scope="module")
def test_images():
    """Download the shared sample images once per module.

    Returns:
        A ``(labels, uris, image_bytes)`` tuple of three parallel lists.

    Raises:
        requests.HTTPError: if any download returns an error status, so an
            error page is never passed on as image bytes.
    """
    import requests

    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    image_bytes = []
    for uri in uris:
        # A timeout keeps a stalled connection from hanging the test session.
        response = requests.get(uri, timeout=30)
        response.raise_for_status()
        image_bytes.append(response.content)
    return labels, uris, image_bytes
@pytest.fixture(scope="module")
def query_image_bytes():
    """Download the image used as the query in image-search tests.

    Returns:
        The raw bytes of the downloaded image.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    import requests

    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
    # A timeout keeps a stalled connection from hanging the test session.
    response = requests.get(query_image_uri, timeout=30)
    response.raise_for_status()
    return response.content
@pytest.mark.slow
def test_openclip(tmp_path, test_images, query_image_bytes):
    """Index images by URI and by bytes with open-clip, then search both ways.

    Text and image queries are run against both vector columns, and the two
    columns are finally compared element-wise.
    """
    from PIL import Image

    clip = get_registry().get("open-clip").create(max_retries=0)

    class Images(LanceModel):
        label: str
        image_uri: str = clip.SourceField()
        image_bytes: bytes = clip.SourceField()
        vector: Vector(clip.ndims()) = clip.VectorField()
        vec_from_bytes: Vector(clip.ndims()) = clip.VectorField()

    labels, uris, raw_images = test_images
    table = lancedb.connect(tmp_path).create_table("images", schema=Images)
    frame = pd.DataFrame(
        {"label": labels, "image_uri": uris, "image_bytes": raw_images}
    )
    table.add(frame)

    def nearest(needle, column):
        # Top-1 hit for `needle` against the given vector column.
        hits = table.search(needle, vector_column_name=column).limit(1)
        return hits.to_pydantic(Images)[0]

    # Text search.
    text_query = "man's best friend"
    via_uri = nearest(text_query, "vector")
    assert via_uri.label == "dog"
    via_bytes = nearest(text_query, "vec_from_bytes")
    assert via_uri.label == via_bytes.label
    assert np.allclose(via_uri.vector, via_bytes.vector)

    # Image search.
    query_image = Image.open(io.BytesIO(query_image_bytes))
    via_uri = nearest(query_image, "vector")
    assert via_uri.label == "dog"
    via_bytes = nearest(query_image, "vec_from_bytes")
    assert via_uri.label == via_bytes.label

    # Both vector columns must hold the same embeddings.
    columns = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
    uri_values = columns["vector"].combine_chunks().values.to_numpy()
    bytes_values = columns["vec_from_bytes"].combine_chunks().values.to_numpy()
    assert np.allclose(uri_values, bytes_values)
@pytest.mark.skipif(
    _imagebind is None,
    reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
    """Index locally downloaded images with imagebind and search them.

    Images are saved into a temporary directory that the context manager
    removes on exit. A failed download raises immediately, so the label and
    path lists can never silently end up with different lengths.
    """
    import tempfile

    import requests

    def download_images(image_uris, target_dir):
        """Save each URI into ``target_dir``; return the local file paths."""
        local_paths = []
        for uri in image_uris:
            response = requests.get(uri, timeout=30)
            response.raise_for_status()
            # Name the local file after the last path component of the URI.
            image_path = os.path.join(target_dir, os.path.basename(uri))
            with open(image_path, "wb") as out_file:
                out_file.write(response.content)
            local_paths.append(image_path)
        return local_paths

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Created temporary directory {temp_dir}")

        db = lancedb.connect(tmp_path)
        func = get_registry().get("imagebind").create(max_retries=0)

        class Images(LanceModel):
            label: str
            image_uri: str = func.SourceField()
            vector: Vector(func.ndims()) = func.VectorField()

        table = db.create_table("images", schema=Images)
        labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
        uris = [
            "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
            "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
            "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
            "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
            "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
            "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
        ]
        downloaded_images = download_images(uris, temp_dir)
        table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))

        # text search
        actual = (
            table.search("man's best friend", vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"

        # image search
        query_image_uri = download_images(
            ["https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"],
            temp_dir,
        )[0]
        actual = (
            table.search(query_image_uri, vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_embedding_function(tmp_path):
    """Embed two sentences with Cohere and check the stored vector length.

    The table is created under pytest's per-test ``tmp_path`` so the test
    neither writes to the home directory nor shares a table with other tests.
    """
    # Skip, rather than error, when the optional client package is missing.
    pytest.importorskip("cohere")

    cohere = (
        get_registry()
        .get("cohere")
        .create(name="embed-multilingual-v2.0", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = cohere.SourceField()
        vector: Vector(cohere.ndims()) = cohere.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
@pytest.mark.slow
def test_instructor_embedding(tmp_path):
    """Embed two sentences with the instructor model; check vector length."""
    embedder = get_registry().get("instructor").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    sentences = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    table = lancedb.connect(tmp_path).create_table(
        "test", schema=TextModel, mode="overwrite"
    )
    table.add(sentences)

    first_vector = table.to_pandas()["vector"][0]
    assert len(first_vector) == embedder.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
)
def test_gemini_embedding(tmp_path):
    """Embed two sentences with gemini-text; check vector length and top hit."""
    embedder = get_registry().get("gemini-text").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    sentences = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    table = lancedb.connect(tmp_path).create_table(
        "test", schema=TextModel, mode="overwrite"
    )
    table.add(sentences)

    first_vector = table.to_pandas()["vector"][0]
    assert len(first_vector) == embedder.ndims()

    best_match = table.search("hello").limit(1).to_pandas()
    assert best_match["text"][0] == "hello world"
@pytest.mark.skipif(
    _mlx is None,
    reason="mlx tests only required for apple users.",
)
@pytest.mark.slow
def test_gte_embedding(tmp_path):
    """Embed two sentences with gte-text; check vector length and top hit."""
    embedder = get_registry().get("gte-text").create()

    class TextModel(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    sentences = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    table = lancedb.connect(tmp_path).create_table(
        "test", schema=TextModel, mode="overwrite"
    )
    table.add(sentences)

    first_vector = table.to_pandas()["vector"][0]
    assert len(first_vector) == embedder.ndims()

    best_match = table.search("hello").limit(1).to_pandas()
    assert best_match["text"][0] == "hello world"
def aws_setup():
    """Report whether boto3 is importable and an STS identity call succeeds.

    Returns:
        True when ``sts.get_caller_identity()`` completes without raising,
        False on any exception (including boto3 not being installed).
    """
    try:
        import boto3

        boto3.client("sts").get_caller_identity()
    except Exception:
        return False
    return True
@pytest.mark.slow
@pytest.mark.skipif(
    not aws_setup(), reason="AWS credentials not set or libraries not installed"
)
def test_bedrock_embedding(tmp_path):
    """For each Bedrock model id, embed two sentences and check vector length."""
    model_ids = (
        "amazon.titan-embed-text-v1",
        "cohere.embed-english-v3",
        "cohere.embed-multilingual-v3",
    )
    for name in model_ids:
        model = get_registry().get("bedrock-text").create(name=name, max_retries=0)

        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        sentences = pd.DataFrame({"text": ["hello world", "goodbye world"]})
        # mode="overwrite" lets every iteration reuse the same table name.
        table = lancedb.connect(tmp_path).create_table(
            "test", schema=TextModel, mode="overwrite"
        )
        table.add(sentences)

        first_vector = table.to_pandas()["vector"][0]
        assert len(first_vector) == model.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
    """Run the same embed-and-search check for three OpenAI configurations.

    Covers the default model, ``text-embedding-3-large``, and
    ``text-embedding-3-large`` created with ``dim=1024``.
    """

    def _get_table(model):
        # Fresh table (overwriting any previous one) bound to `model`.
        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        return lancedb.connect(tmp_path).create_table(
            "test", schema=TextModel, mode="overwrite"
        )

    sentences = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    variants = (
        {},
        {"name": "text-embedding-3-large"},
        {"name": "text-embedding-3-large", "dim": 1024},
    )
    for extra_kwargs in variants:
        model = get_registry().get("openai").create(max_retries=0, **extra_kwargs)
        table = _get_table(model)
        table.add(sentences)

        first_vector = table.to_pandas()["vector"][0]
        assert len(first_vector) == model.ndims()
        best_match = table.search("hello").limit(1).to_pandas()
        assert best_match["text"][0] == "hello world"
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("WATSONX_API_KEY") is None
    or os.environ.get("WATSONX_PROJECT_ID") is None,
    reason="WATSONX_API_KEY and WATSONX_PROJECT_ID not set",
)
def test_watsonx_embedding(tmp_path):
    """For each watsonx model name, embed two sentences and search them.

    The table is created under the ``tmp_path`` fixture the test already
    requests, instead of a fixed location in the home directory.
    """
    from lancedb.embeddings import WatsonxEmbeddings

    for name in WatsonxEmbeddings.model_names():
        model = get_registry().get("watsonx").create(max_retries=0, name=name)

        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        db = lancedb.connect(tmp_path)
        # mode="overwrite" lets every iteration reuse the same table name.
        tbl = db.create_table("watsonx_test", schema=TextModel, mode="overwrite")
        df = pd.DataFrame({"text": ["hello world", "goodbye world"]})

        tbl.add(df)
        assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
        assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_with_empty_strs(tmp_path):
    """Add one normal and one empty string with ``on_bad_vectors="skip"``.

    Asserts that both rows are stored, the vector column keeps its
    fixed-size float32 list type, and only the empty string's vector is null.
    """
    model = get_registry().get("openai").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", ""]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df, on_bad_vectors="skip")
    tb = tbl.to_arrow()
    # Schema.field() replaces the deprecated Schema.field_by_name().
    assert tb.schema.field("vector").type == pa.list_(pa.float32(), model.ndims())
    assert len(tb) == 2
    assert tb["vector"].is_null().to_pylist() == [False, True]
@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("ollama") is None, reason="Ollama not installed"
)
def test_ollama_embedding(tmp_path):
    """Embed and search with ollama, then check ``safe_model_dump`` output.

    The dump must be a dict with ``name`` and ``max_retries``, contain no
    underscore-prefixed keys, and be JSON-serializable.
    """
    import json

    embedder = get_registry().get("ollama").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    sentences = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    table = lancedb.connect(tmp_path).create_table(
        "test", schema=TextModel, mode="overwrite"
    )
    table.add(sentences)

    first_vector = table.to_pandas()["vector"][0]
    assert len(first_vector) == embedder.ndims()
    best_match = table.search("hello").limit(1).to_pandas()
    assert best_match["text"][0] == "hello world"

    # Test safe_model_dump
    dumped = embedder.safe_model_dump()
    assert isinstance(dumped, dict)
    for required_key in ("name", "max_retries"):
        assert required_key in dumped
    assert dumped["max_retries"] == 0
    assert not any(key.startswith("_") for key in dumped)

    # Test serialization of the dumped model
    try:
        json.dumps(dumped)
    except TypeError:
        pytest.fail("Failed to JSON serialize the dumped model")
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function(tmp_path):
    """Embed two sentences with ``voyage-3`` and check the vector length.

    The table is created under pytest's per-test ``tmp_path`` rather than a
    shared location in the home directory.
    """
    voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function_contextual_model(tmp_path):
    """Embed two sentences with ``voyage-context-3``; check vector length.

    The table is created under pytest's per-test ``tmp_path`` rather than a
    shared location in the home directory.
    """
    voyageai = (
        get_registry().get("voyageai").create(name="voyage-context-3", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_function(tmp_path):
    """Embed images by URI and by bytes with ``voyage-multimodal-3``.

    The table is created under pytest's per-test ``tmp_path`` rather than a
    shared location in the home directory. A failed image download raises
    instead of passing an error body on as image bytes.
    """
    import requests

    voyageai = (
        get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
    )

    class Images(LanceModel):
        label: str
        image_uri: str = voyageai.SourceField()  # image uri as the source
        image_bytes: bytes = voyageai.SourceField()  # image bytes as the source
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()  # vector column
        vec_from_bytes: Vector(voyageai.ndims()) = (
            voyageai.VectorField()
        )  # Another vector column

    db = lancedb.connect(tmp_path)
    table = db.create_table("test", schema=Images, mode="overwrite")
    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    # get each uri as bytes
    image_bytes = []
    for uri in uris:
        response = requests.get(uri, timeout=30)
        response.raise_for_status()
        image_bytes.append(response.content)

    table.add(
        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
    )
    assert len(table.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_multimodal_embedding_text_function(tmp_path):
    """Embed plain text with ``voyage-multimodal-3``; check vector length.

    The table is created under pytest's per-test ``tmp_path`` rather than a
    shared location in the home directory.
    """
    voyageai = (
        get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = voyageai.SourceField()
        vector: Vector(voyageai.ndims()) = voyageai.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("colpali_engine") is None,
    reason="colpali_engine not installed",
)
def test_colpali(tmp_path):
    """Index six captioned images with colpali and run a text-to-image search.

    Also checks that the stored multivector has more than one vector per row
    and that each vector has ``ndims()`` entries.
    """
    import requests
    from lancedb.pydantic import LanceModel

    colpali = get_registry().get("colpali").create()

    class MediaItems(LanceModel):
        text: str
        image_uri: str = colpali.SourceField()
        image_bytes: bytes = colpali.SourceField()
        image_vectors: MultiVector(colpali.ndims()) = (
            colpali.VectorField()
        )  # Multivector image embeddings

    table = lancedb.connect(tmp_path).create_table("media", schema=MediaItems)

    captions = [
        "a cute cat playing with yarn",
        "a puppy in a flower field",
        "a red sports car on the highway",
        "a vintage bicycle leaning against a wall",
        "a plate of delicious pasta",
        "fresh fruit salad in a bowl",
    ]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    # Get images as bytes
    raw_images = [requests.get(uri).content for uri in uris]
    frame = pd.DataFrame(
        {"text": captions, "image_uri": uris, "image_bytes": raw_images}
    )
    table.add(frame)

    # Test text-to-image search
    hits = table.search("fluffy companion", vector_column_name="image_vectors")
    top_hit = hits.limit(1).to_pydantic(MediaItems)[0]
    caption = top_hit.text.lower()
    assert "cat" in caption or "puppy" in caption

    # Verify multivector dimensions
    stored_vectors = table.to_arrow().to_pylist()[0]["image_vectors"]
    assert len(stored_vectors) > 1, "Should have multiple image vectors"
    assert len(stored_vectors[0]) == colpali.ndims(), "Vector dimension mismatch"
@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("colpali_engine") is None,
    reason="colpali_engine not installed",
)
@pytest.mark.parametrize(
    "model_name",
    [
        "vidore/colSmol-256M",
        "vidore/colqwen2.5-v0.2",
        "vidore/colpali-v1.3",
        "vidore/colqwen2-v1.0",
    ],
)
def test_colpali_models(tmp_path, model_name):
    """Index one captioned image with each colpali model variant and search it.

    Also checks that the stored multivector has more than one vector and
    that each vector has ``ndims()`` entries.
    """
    import requests
    from lancedb.pydantic import LanceModel

    colpali = get_registry().get("colpali").create(model_name=model_name)

    class MediaItems(LanceModel):
        text: str
        image_uri: str = colpali.SourceField()
        image_bytes: bytes = colpali.SourceField()
        image_vectors: MultiVector(colpali.ndims()) = colpali.VectorField()

    # The model id contains "/", which is replaced to build the table name.
    table_name = "media_" + model_name.replace("/", "_")
    table = lancedb.connect(tmp_path).create_table(table_name, schema=MediaItems)

    captions = ["a cute cat playing with yarn"]
    uris = ["http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"]
    raw_images = [requests.get(uri).content for uri in uris]
    frame = pd.DataFrame(
        {"text": captions, "image_uri": uris, "image_bytes": raw_images}
    )
    table.add(frame)

    hits = table.search("fluffy companion", vector_column_name="image_vectors")
    top_hit = hits.limit(1).to_pydantic(MediaItems)[0]
    caption = top_hit.text.lower()
    assert "cat" in caption or "puppy" in caption

    stored_vectors = table.to_arrow().to_pylist()[0]["image_vectors"]
    assert len(stored_vectors) > 1, "Should have multiple image vectors"
    assert len(stored_vectors[0]) == colpali.ndims(), "Vector dimension mismatch"
@pytest.mark.slow
@pytest.mark.skipif(
    importlib.util.find_spec("colpali_engine") is None,
    reason="colpali_engine not installed",
)
def test_colpali_pooling(tmp_path):
    """Compare unpooled, hierarchical and lambda-pooled text embedding counts.

    With ``pool_factor=2``, and with a lambda that keeps every other vector,
    the test expects ``(n + 1) // 2`` vectors, where ``n`` is the unpooled
    count.
    """
    colpali = get_registry().get("colpali")
    model_name = "vidore/colSmol-256M"
    sentences = ["a test sentence for pooling"]

    # 1. Get embeddings with no pooling
    unpooled_func = colpali.create(model_name=model_name, pooling_strategy=None)
    original_length = len(unpooled_func.generate_text_embeddings(sentences)[0])
    assert original_length > 1

    # Both pooled variants are expected to halve the count, rounding up.
    halved_length = (original_length + 1) // 2

    # 2. Test hierarchical pooling
    hierarchical_func = colpali.create(
        model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
    )
    hierarchical = hierarchical_func.generate_text_embeddings(sentences)[0]
    assert len(hierarchical) == halved_length

    # 3. Test lambda pooling
    def simple_pool_func(tensor):
        # Keep every other entry along the first axis.
        return tensor[::2]

    lambda_func = colpali.create(
        model_name=model_name,
        pooling_strategy="lambda",
        pooling_func=simple_pool_func,
    )
    pooled = lambda_func.generate_text_embeddings(sentences)[0]
    assert len(pooled) == halved_length
@pytest.mark.slow
def test_siglip(tmp_path, test_images, query_image_bytes):
    """Index images by URI and by bytes with siglip, then search both ways.

    Text and image queries are run against both vector columns, and the two
    columns are finally compared element-wise.
    """
    from PIL import Image

    labels, uris, raw_images = test_images
    siglip = get_registry().get("siglip").create(max_retries=0)

    class Images(LanceModel):
        label: str
        image_uri: str = siglip.SourceField()
        image_bytes: bytes = siglip.SourceField()
        vector: Vector(siglip.ndims()) = siglip.VectorField()
        vec_from_bytes: Vector(siglip.ndims()) = siglip.VectorField()

    table = lancedb.connect(tmp_path).create_table("images", schema=Images)
    frame = pd.DataFrame(
        {"label": labels, "image_uri": uris, "image_bytes": raw_images}
    )
    table.add(frame)

    def best_match(needle, column):
        # Top-1 hit for `needle` against the given vector column.
        hits = table.search(needle, vector_column_name=column).limit(1)
        return hits.to_pydantic(Images)[0]

    # Text search
    prompt = "man's best friend"
    uri_hit = best_match(prompt, "vector")
    assert uri_hit.label == "dog"
    bytes_hit = best_match(prompt, "vec_from_bytes")
    assert uri_hit.label == bytes_hit.label
    assert np.allclose(uri_hit.vector, bytes_hit.vector)

    # Image search
    query_image = Image.open(io.BytesIO(query_image_bytes))
    uri_hit = best_match(query_image, "vector")
    assert uri_hit.label == "dog"
    bytes_hit = best_match(query_image, "vec_from_bytes")
    assert uri_hit.label == bytes_hit.label

    # Both vector columns must hold the same embeddings.
    both_columns = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
    uri_values = both_columns["vector"].combine_chunks().values.to_numpy()
    bytes_values = both_columns["vec_from_bytes"].combine_chunks().values.to_numpy()
    assert np.allclose(uri_values, bytes_values)