ayush chaurasia
2025-10-13 15:17:50 +05:30
parent dadb042978
commit b766cbe0a9
6 changed files with 426 additions and 231 deletions

View File

@@ -80,7 +80,7 @@ embeddings = [
     "pillow",
     "open-clip-torch",
     "cohere",
-    "colpali-engine>=0.3.10",
+    "colpali-engine>=0.3.12",
    "huggingface_hub",
    "InstructorEmbedding",
    "google.generativeai",

View File

@@ -19,5 +19,5 @@ from .imagebind import ImageBindEmbeddings
 from .jinaai import JinaEmbeddings
 from .watsonx import WatsonxEmbeddings
 from .voyageai import VoyageAIEmbeddingFunction
-from .colpali import ColPaliEmbeddings
+from .colpali import MultimodalLateInteractionEmbeddings, ColPaliEmbeddings  # noqa: F401
 from .siglip import SigLipEmbeddings
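
Both spellings now resolve to one implementation, so downstream imports keep working. A quick compatibility check (a sketch, assuming a build with this commit is installed):

import lancedb.embeddings as e

# The legacy name is an alias of the new class, not a second registration.
assert e.ColPaliEmbeddings is e.MultimodalLateInteractionEmbeddings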

python/python/lancedb/embeddings/colpali.py Normal file → Executable file
View File

@@ -1,255 +1,347 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors
+"""Late-interaction embeddings powered by colpali-engine."""
+
+from __future__ import annotations
+
-from functools import lru_cache
-from typing import List, Union, Optional, Any
-import numpy as np
 import io
+from typing import Any, Dict, List, Optional, Sequence
+
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
 from .registry import register
-from .utils import TEXT, IMAGES, is_flash_attn_2_available
+from .utils import IMAGES, TEXT, is_flash_attn_2_available, weak_lru
 
-@register("colpali")
-class ColPaliEmbeddings(EmbeddingFunction):
-    """
-    An embedding function that uses the ColPali engine for
-    multimodal multi-vector embeddings.
-
-    This embedding function supports ColQwen2.5 models, producing multivector outputs
-    for both text and image inputs. The output embeddings are lists of vectors, each
-    vector being 128-dimensional by default, represented as List[List[float]].
-
-    Parameters
-    ----------
-    model_name : str
-        The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
-    device : str
-        The device for inference (default "cuda:0").
-    dtype : str
-        Data type for model weights (default "bfloat16").
-    use_token_pooling : bool
-        Whether to use token pooling to reduce embedding size (default True).
-    pool_factor : int
-        Factor to reduce sequence length if token pooling is enabled (default 2).
-    quantization_config : Optional[BitsAndBytesConfig]
-        Quantization configuration for the model. (default None, bitsandbytes needed)
-    batch_size : int
-        Batch size for processing inputs (default 2).
-    """
-
-    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
+_FAMILY_ALIASES = {
+    "colsmol": {"colsmol", "colsmolvlm", "smol"},
+    "colqwen2.5": {"colqwen2.5", "colqwen25", "colqwen-2.5"},
+    "colqwen2": {"colqwen2", "colqwen-2"},
+    "colpali": {"colpali", "paligemma"},
+}
+
+_FAMILY_CLASSES = {
+    "colpali": ("ColPali", "ColPaliProcessor"),
+    "colqwen2.5": ("ColQwen2_5", "ColQwen2_5_Processor"),
+    "colqwen2": ("ColQwen2", "ColQwen2Processor"),
+    "colsmol": ("ColIdefics3", "ColIdefics3Processor"),
+}
+
+
+def _torch() -> Any:
+    return attempt_import_or_raise("torch", "torch")
+
+
+def _torch_dtype(dtype: str) -> Any:
+    torch = _torch()
+    mapping = {
+        "bfloat16": torch.bfloat16,
+        "float16": torch.float16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+    }
+    if dtype not in mapping:
+        raise ValueError(
+            "Unsupported dtype '{}'. Expected one of {}".format(
+                dtype, ", ".join(sorted(mapping))
+            )
+        )
+    return mapping[dtype]
+
+
+def _load_pooler(use_pooler: bool) -> Optional[Any]:
+    if not use_pooler:
+        return None
+    token_pooling = attempt_import_or_raise(
+        "colpali_engine.compression.token_pooling", "colpali-engine"
+    )
+    pooler_cls = getattr(token_pooling, "HierarchicalTokenPooler", None)
+    if pooler_cls is None:
+        raise ImportError(
+            "colpali_engine HierarchicalTokenPooler not available; update colpali-engine"
+        )
+    return pooler_cls()
+
+
+def _move_to_device(batch: Any, device: Any) -> Any:
+    if device is None:
+        return batch
+    torch = _torch()
+    if isinstance(device, str):
+        device_obj = torch.device(device)
+    else:
+        device_obj = device
+    if isinstance(batch, dict):
+        return {k: _move_to_device(v, device_obj) for k, v in batch.items()}
+    if hasattr(batch, "to"):
+        return batch.to(device_obj)
+    return batch
+
+
+@register("multimodal-late-interaction")
+class MultimodalLateInteractionEmbeddings(EmbeddingFunction):
+    """Late-interaction embeddings for ViDoRe models."""
+
+    model_name: str = "vidore/colSmol-256M"
+    model_family: Optional[str] = None
-    device: str = "cuda:0"
+    device: str = "auto"
     dtype: str = "bfloat16"
     use_token_pooling: bool = True
     pool_factor: int = 2
+    batch_size: int = 4
     quantization_config: Optional[Any] = None
-    batch_size: int = 2
-
-    _model = None
-    _processor = None
-    _token_pooler = None
-    _vector_dim = None
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
-        (
-            self._model,
-            self._processor,
-            self._token_pooler,
-        ) = self._load_model(
-            self.model_name,
-            self.dtype,
-            self.device,
-            self.use_token_pooling,
-            self.quantization_config,
-        )
+        self._family = self._resolve_family(self.model_name, self.model_family)
+        self._vector_dim: Optional[int] = None
 
-    @staticmethod
-    @lru_cache(maxsize=1)
-    def _load_model(
-        model_name: str,
-        dtype: str,
-        device: str,
-        use_token_pooling: bool,
-        quantization_config: Optional[Any],
-    ):
-        """
-        Initialize and cache the ColPali model, processor, and token pooler.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
+    @property
+    def model(self) -> Any:
+        """The cached model."""
+        return self._get_models()[0]
+
+    @property
+    def processor(self) -> Any:
+        """The cached processor."""
+        return self._get_models()[1]
+
+    @property
+    def pooler(self) -> Optional[Any]:
+        """The cached pooler."""
+        return self._get_models()[2]
+
+    @property
+    def target_device(self) -> Optional[Any]:
+        """The cached target device."""
+        return self._get_models()[3]
+
+    # ------------------------------------------------------------------
+    # Family detection
+    # ------------------------------------------------------------------
+    @classmethod
+    def _resolve_family(cls, model_name: str, explicit: Optional[str]) -> str:
+        if explicit:
+            family = explicit.lower()
+            if family not in _FAMILY_CLASSES:
+                raise ValueError(
+                    "Unknown model_family '{}'. Expected one of {}".format(
+                        explicit, ", ".join(sorted(_FAMILY_CLASSES))
+                    )
+                )
+            return family
+        lowered = model_name.lower()
+        for family, aliases in _FAMILY_ALIASES.items():
+            if any(alias in lowered for alias in aliases):
+                return family
+        return "colpali"
+
+    # ------------------------------------------------------------------
+    # Model loading
+    # ------------------------------------------------------------------
+    @weak_lru(maxsize=1)
+    def _get_models(self) -> tuple[Any, Optional[Any], Optional[Any], Optional[Any]]:
+        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali-engine")
         transformers = attempt_import_or_raise("transformers", "transformers")
-        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
-        from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
 
-        if quantization_config is not None:
-            if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
-                raise ValueError("quantization_config must be a BitsAndBytesConfig")
+        if (
+            self.quantization_config is not None
+            and not isinstance(
+                self.quantization_config, transformers.BitsAndBytesConfig
+            )
+        ):
+            raise ValueError(
+                "quantization_config must be a transformers.BitsAndBytesConfig instance"
+            )
 
-        if dtype == "bfloat16":
-            torch_dtype = torch.bfloat16
-        elif dtype == "float16":
-            torch_dtype = torch.float16
-        elif dtype == "float64":
-            torch_dtype = torch.float64
+        model_cls_name, processor_cls_name = _FAMILY_CLASSES[self._family]
+        model_cls = getattr(colpali_engine.models, model_cls_name)
+        processor_cls = getattr(colpali_engine.models, processor_cls_name)
+
+        torch = _torch()
+        device_map = self.device
+        target_device: Optional[Any] = None
+        if device_map == "auto":
+            if torch.cuda.is_available():
+                device_map = "cuda:0"
+                target_device = torch.device("cuda:0")
+            elif (
+                getattr(torch.backends, "mps", None)
+                and torch.backends.mps.is_available()
+            ):
+                device_map = "mps"
+                target_device = torch.device("mps")
+            else:
+                device_map = "cpu"
+                target_device = torch.device("cpu")
+        else:
+            try:
+                target_device = torch.device(device_map)
+            except (TypeError, ValueError):  # pragma: no cover - device map dicts
+                target_device = None
+
+        torch_dtype = _torch_dtype(self.dtype)
+        if isinstance(device_map, str) and device_map == "cpu" and torch_dtype in {
+            torch.bfloat16,
+            torch.float16,
+        }:
+            torch_dtype = torch.float32
 
-        model = colpali_engine.models.ColQwen2_5.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            device_map=device,
-            quantization_config=quantization_config
-            if quantization_config is not None
-            else None,
-            attn_implementation="flash_attention_2"
-            if is_flash_attn_2_available()
-            else None,
-        ).eval()
-        processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
-            model_name
-        )
+        load_kwargs: Dict[str, Any] = {
+            "torch_dtype": torch_dtype,
+            "device_map": device_map,
+        }
+        if self.quantization_config is not None:
+            load_kwargs["quantization_config"] = self.quantization_config
+        attn_impl = "flash_attention_2" if is_flash_attn_2_available() else None
+        if attn_impl is not None:
+            load_kwargs["attn_implementation"] = attn_impl
+
+        model = model_cls.from_pretrained(self.model_name, **load_kwargs)
+        if hasattr(model, "eval"):
+            model = model.eval()
+        processor = processor_cls.from_pretrained(self.model_name)
+        pooler = _load_pooler(self.use_token_pooling)
+        if target_device is None and hasattr(model, "device"):
+            target_device = getattr(model, "device")
+        return model, processor, pooler, target_device
 
-        token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
-        return model, processor, token_pooler
+    # ------------------------------------------------------------------
+    # Encoding helpers
+    # ------------------------------------------------------------------
+    def _pool_tensor(self, tensor: Any) -> Any:
+        if self.pooler is None:
+            return tensor
+        torch = _torch()
+        assert isinstance(tensor, torch.Tensor)
+        expanded = False
+        if tensor.ndim == 2:
+            tensor = tensor.unsqueeze(0)
+            expanded = True
+        kwargs = {"pool_factor": self.pool_factor, "padding": True}
+        tokenizer = getattr(
+            getattr(self.processor, "tokenizer", None), "padding_side", None
+        )
+        if tokenizer is not None:
+            kwargs["padding_side"] = tokenizer
+        pooled = self.pooler.pool_embeddings(tensor, **kwargs)
+        if expanded:
+            pooled = pooled.squeeze(0)
+        return pooled
 
-    def ndims(self):
-        """
-        Return the dimension of a vector in the multivector output (e.g., 128).
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        if self._vector_dim is None:
-            dummy_query = "test"
-            batch_queries = self._processor.process_queries([dummy_query]).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                query_embeddings = self._model(**batch_queries)
-            if self.use_token_pooling and self._token_pooler is not None:
-                query_embeddings = self._token_pooler.pool_embeddings(
-                    query_embeddings,
-                    pool_factor=self.pool_factor,
-                    padding=True,
-                    padding_side=self._processor.tokenizer.padding_side,
-                )
-            self._vector_dim = query_embeddings[0].shape[-1]
-        return self._vector_dim
-
-    def _process_embeddings(self, embeddings):
-        """
-        Format model embeddings into List[List[float]].
-        Use token pooling if enabled.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        if self.use_token_pooling and self._token_pooler is not None:
-            embeddings = self._token_pooler.pool_embeddings(
-                embeddings,
-                pool_factor=self.pool_factor,
-                padding=True,
-                padding_side=self._processor.tokenizer.padding_side,
-            )
+    def _normalize_output(self, embeddings: Any) -> List[List[List[float]]]:
+        torch = _torch()
+        if hasattr(embeddings, "last_hidden_state"):
+            return self._normalize_output(embeddings.last_hidden_state)
+        if isinstance(embeddings, dict) and "last_hidden_state" in embeddings:
+            return self._normalize_output(embeddings["last_hidden_state"])
         if isinstance(embeddings, torch.Tensor):
-            tensors = embeddings.detach().cpu()
-            if tensors.dtype == torch.bfloat16:
-                tensors = tensors.to(torch.float32)
-            return (
-                tensors.numpy()
-                .astype(np.float64 if self.dtype == "float64" else np.float32)
-                .tolist()
-            )
-        return []
+            pooled = self._pool_tensor(embeddings).detach().cpu()
+            if pooled.ndim == 2:
+                pooled = pooled.unsqueeze(0)
+            target = torch.float64 if self.dtype == "float64" else torch.float32
+            return pooled.to(target).numpy().tolist()
+        if isinstance(embeddings, (list, tuple)):
+            results: List[List[List[float]]] = []
+            for item in embeddings:
+                results.extend(self._normalize_output(item))
+            return results
+        raise TypeError(f"Unsupported embedding type {type(embeddings)}")
 
-    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
-        """
-        Generate embeddings for text input.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        text = self.sanitize_input(text)
-        all_embeddings = []
-
-        for i in range(0, len(text), self.batch_size):
-            batch_text = text[i : i + self.batch_size]
-            batch_queries = self._processor.process_queries(batch_text).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                query_embeddings = self._model(**batch_queries)
-            all_embeddings.extend(self._process_embeddings(query_embeddings))
-        return all_embeddings
+    # ------------------------------------------------------------------
+    # Text encoding
+    # ------------------------------------------------------------------
+    def _encode_text(self, batch: Sequence[str]) -> List[List[List[float]]]:
+        if not self.processor or not hasattr(self.processor, "process_queries"):
+            raise RuntimeError("Processor for text queries is not available for this model")
+        payload = self.processor.process_queries(batch)
+        payload = _move_to_device(payload, self.target_device)
+        torch = _torch()
+        with torch.no_grad():
+            outputs = self.model(**payload)
+        return self._normalize_output(outputs)
 
-    def _prepare_images(self, images: IMAGES) -> List:
-        """
-        Convert image inputs to PIL Images.
-        """
+    # ------------------------------------------------------------------
+    # Image encoding
+    # ------------------------------------------------------------------
+    def _prepare_images(self, images: IMAGES) -> List[Any]:
         PIL = attempt_import_or_raise("PIL", "pillow")
         requests = attempt_import_or_raise("requests", "requests")
-        images = self.sanitize_input(images)
-        pil_images = []
-        try:
-            for image in images:
-                if isinstance(image, str):
-                    if image.startswith(("http://", "https://")):
-                        response = requests.get(image, timeout=10)
-                        response.raise_for_status()
-                        pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
-                    else:
-                        with PIL.Image.open(image) as im:
-                            pil_images.append(im.copy())
-                elif isinstance(image, bytes):
-                    pil_images.append(PIL.Image.open(io.BytesIO(image)))
-                else:
-                    # Assume it's a PIL Image; will raise if invalid
-                    pil_images.append(image)
-        except Exception as e:
-            raise ValueError(f"Failed to process image: {e}")
-        return pil_images
+        prepared: List[Any] = []
+        for image in self.sanitize_input(images):
+            if isinstance(image, str) and image.startswith(("http://", "https://")):
+                response = requests.get(image, timeout=10)
+                response.raise_for_status()
+                prepared.append(PIL.Image.open(io.BytesIO(response.content)))
+            elif isinstance(image, str):
+                with PIL.Image.open(image) as img:
+                    prepared.append(img.copy())
+            elif isinstance(image, bytes):
+                prepared.append(PIL.Image.open(io.BytesIO(image)))
+            else:
+                prepared.append(image)
+        return prepared
+
+    def _encode_images(self, images: Sequence[Any]) -> List[List[List[float]]]:
+        if not self.processor or not hasattr(self.processor, "process_images"):
+            raise RuntimeError("Processor for images is not available for this model")
+        payload = self.processor.process_images(images)
+        payload = _move_to_device(payload, self.target_device)
+        torch = _torch()
+        with torch.no_grad():
+            outputs = self.model(**payload)
+        return self._normalize_output(outputs)
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+    def _batched(self, values: Sequence[Any], encoder) -> List[List[List[float]]]:
+        results: List[List[List[float]]] = []
+        for start in range(0, len(values), self.batch_size):
+            chunk = values[start : start + self.batch_size]
+            results.extend(encoder(chunk))
+        return results
+
+    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
+        text = self.sanitize_input(text)
+        if len(text) == 0:
+            return []
+        return self._batched(text, self._encode_text)
 
     def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
-        """
-        Generate embeddings for a batch of images.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        pil_images = self._prepare_images(images)
-        all_embeddings = []
-        for i in range(0, len(pil_images), self.batch_size):
-            batch_images = pil_images[i : i + self.batch_size]
-            batch_images = self._processor.process_images(batch_images).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                image_embeddings = self._model(**batch_images)
-            all_embeddings.extend(self._process_embeddings(image_embeddings))
-        return all_embeddings
+        prepared = self._prepare_images(images)
+        if len(prepared) == 0:
+            return []
+        return self._batched(prepared, self._encode_images)
 
     def compute_query_embeddings(
-        self, query: Union[str, IMAGES], *args, **kwargs
+        self, query: str, *args: Any, **kwargs: Any
     ) -> List[List[List[float]]]:
-        """
-        Compute embeddings for a single user query (text only).
-        """
         if not isinstance(query, str):
-            raise ValueError(
-                "Query must be a string, image to image search is not supported"
-            )
+            raise ValueError("Late interaction queries must be text")
         return self.generate_text_embeddings([query])
 
     def compute_source_embeddings(
-        self, images: IMAGES, *args, **kwargs
+        self, images: IMAGES, *args: Any, **kwargs: Any
     ) -> List[List[List[float]]]:
-        """
-        Compute embeddings for a batch of source images.
-
-        Parameters
-        ----------
-        images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
-            Batch of images (paths, URLs, bytes, or PIL Images).
-        """
         images = self.sanitize_input(images)
         return self.generate_image_embeddings(images)
 
+    def ndims(self) -> int:
+        if self._vector_dim is None:
+            probe = self.generate_text_embeddings(["probe"])
+            if not probe or not probe[0]:
+                raise RuntimeError("Failed to determine embedding dimension")
+            self._vector_dim = len(probe[0][0])
+        return self._vector_dim
+
+
+# Backwards compatibility: keep the historical "colpali" key
+register("colpali")(MultimodalLateInteractionEmbeddings)
+
+# Legacy class name kept for backwards compatibility in imports
+ColPaliEmbeddings = MultimodalLateInteractionEmbeddings
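
For reviewers, here is how the rewritten function is exercised end to end. This is a sketch, not part of the commit; it assumes a machine where colpali-engine and the small default checkpoint are available:

from lancedb.embeddings import get_registry

func = get_registry().get("multimodal-late-interaction").create(
    model_name="vidore/colSmol-256M",  # small default checkpoint
    device="auto",
    use_token_pooling=True,
    pool_factor=2,  # halves the number of token vectors per input
)
# Each input yields a multivector: one vector per (pooled) token.
[query] = func.compute_query_embeddings("a cute cat")
print(len(query), len(query[0]))  # (num_token_vectors, func.ndims())

The multivector output exists to support late-interaction (MaxSim-style) scoring, where each query token vector is matched against its best document token vector and the maxima are summed. A toy illustration of that scoring rule in plain Python, not LanceDB's internal implementation:

def maxsim(query_vecs, doc_vecs):
    # Sum over query tokens of the best dot-product match among doc tokens.
    return sum(
        max(sum(q * d for q, d in zip(qv, dv)) for dv in doc_vecs)
        for qv in query_vecs
    )

q = [[1.0, 0.0], [0.0, 1.0]]
doc = [[0.9, 0.1], [0.2, 0.8]]
print(maxsim(q, doc))  # 0.9 + 0.8 = 1.7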

View File

@@ -17,6 +17,7 @@ from lancedb.embeddings import (
EmbeddingFunctionRegistry,
)
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.colpali import MultimodalLateInteractionEmbeddings
from lancedb.embeddings.registry import get_registry, register
from lancedb.embeddings.utils import retry
from lancedb.pydantic import LanceModel, Vector
@@ -515,3 +516,16 @@ def test_openai_propagates_api_key(monkeypatch):
     query = "greetings"
     actual = table.search(query).limit(1).to_pydantic(Words)[0]
     assert len(actual.text) > 0
+
+
+def test_multimodal_late_interaction_family_detection():
+    resolver = MultimodalLateInteractionEmbeddings._resolve_family
+    assert resolver("vidore/colSmol-256M", None) == "colsmol"
+    assert resolver("vidore/colqwen2.5-v0.2", None) == "colqwen2.5"
+    assert resolver("vidore/colqwen2-v1.0", None) == "colqwen2"
+    assert resolver("vidore/colpali-v1.3", None) == "colpali"
+    assert resolver("custom/model", None) == "colpali"
+    assert resolver("any/model", "colqwen2") == "colqwen2"
+    with pytest.raises(ValueError):
+        resolver("any/model", "unknown")

View File

@@ -4,6 +4,7 @@
 import importlib
 import io
+import os
 
 import lancedb
 import numpy as np
 import pandas as pd
@@ -33,6 +34,24 @@ try:
 except Exception:
     _imagebind = None
 
+try:
+    import torch
+except ImportError:
+    torch = None
+
+HAS_ACCEL = bool(
+    torch
+    and (
+        torch.cuda.is_available()
+        or getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+    )
+)
+
+RUN_HEAVY_VIDORE = os.getenv("LANCEDB_TEST_FULL_LATE_INTERACTION") in {"1", "true", "yes"}
+HEAVY_SKIP = pytest.mark.skipif(
+    not (RUN_HEAVY_VIDORE and HAS_ACCEL),
+    reason="Set LANCEDB_TEST_FULL_LATE_INTERACTION=1 and run on GPU to exercise large vidore checkpoints",
+)
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
@@ -597,21 +616,44 @@ def test_voyageai_multimodal_embedding_text_function():
     importlib.util.find_spec("colpali_engine") is None,
     reason="colpali_engine not installed",
 )
-def test_colpali(tmp_path):
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        # pytest.param("vidore/colSmol-256M", id="colSmol"),
+        pytest.param(
+            "vidore/colqwen2-v1.0",
+            id="colQwen2",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colqwen2.5-v0.2",
+            id="colQwen2.5",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colpali-v1.3-merged",
+            id="colPali",
+            # marks=HEAVY_SKIP,
+        ),
+    ],
+)
+def test_multimodal_late_interaction_models(tmp_path, model_name):
     import requests
-    from lancedb.pydantic import LanceModel
+    from lancedb.pydantic import LanceModel, Vector
 
     db = lancedb.connect(tmp_path)
     registry = get_registry()
-    func = registry.get("colpali").create()
+    func = registry.get("multimodal-late-interaction").create(
+        model_name=model_name,
+        device="auto",
+        batch_size=1,
+    )
 
     class MediaItems(LanceModel):
         text: str
         image_uri: str = func.SourceField()
         image_bytes: bytes = func.SourceField()
-        image_vectors: MultiVector(func.ndims()) = (
-            func.VectorField()
-        )  # Multivector image embeddings
+        image_vectors: MultiVector(func.ndims()) = func.VectorField()
 
     table = db.create_table("media", schema=MediaItems)
@@ -619,41 +661,30 @@ def test_colpali(tmp_path):
         "a cute cat playing with yarn",
         "a puppy in a flower field",
         "a red sports car on the highway",
-        "a vintage bicycle leaning against a wall",
-        "a plate of delicious pasta",
-        "fresh fruit salad in a bowl",
     ]
     uris = [
         "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
         "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
-        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
         "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
-        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
-        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
     ]
-    # Get images as bytes
     image_bytes = [requests.get(uri).content for uri in uris]
 
     table.add(
         pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
     )
 
-    # Test text-to-image search
-    image_results = (
+    result = (
         table.search("fluffy companion", vector_column_name="image_vectors")
         .limit(1)
         .to_pydantic(MediaItems)[0]
     )
-    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
+    assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
 
     # Verify multivector dimensions
     first_row = table.to_arrow().to_pylist()[0]
-    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
-    assert len(first_row["image_vectors"][0]) == func.ndims(), (
-        "Vector dimension mismatch"
-    )
+    assert len(first_row["image_vectors"]) > 1
+    assert len(first_row["image_vectors"][0]) == func.ndims()
 
 
 @pytest.mark.slow
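
The heavy checkpoints only run when both the env flag and an accelerator are present. One way to opt in locally (a sketch; the pytest keyword filter is an assumption about how these tests are usually selected):

import os
import subprocess

os.environ["LANCEDB_TEST_FULL_LATE_INTERACTION"] = "1"  # flips RUN_HEAVY_VIDORE
# HAS_ACCEL must also be true (CUDA or MPS), or HEAVY_SKIP still skips.
subprocess.run(["pytest", "-k", "multimodal_late_interaction"], check=False)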

python/test.py Normal file
View File

@@ -0,0 +1,58 @@
+import pandas as pd
+import requests
+
+import lancedb
+from lancedb.embeddings import get_registry
+from lancedb.pydantic import LanceModel, MultiVector
+
+db = lancedb.connect("~/.db")
+registry = get_registry()
+func = registry.get("multimodal-late-interaction").create(
+    model_name="vidore/colQwen2.5-v0.2",
+    device="auto",
+    batch_size=1,
+)
+
+
+class MediaItems(LanceModel):
+    text: str
+    image_uri: str = func.SourceField()
+    image_bytes: bytes = func.SourceField()
+    image_vectors: MultiVector(func.ndims()) = func.VectorField()
+
+
+table = db.create_table("media", schema=MediaItems, mode="overwrite")
+
+texts = [
+    "a cute cat playing with yarn",
+    "a puppy in a flower field",
+    "a red sports car on the highway",
+]
+uris = [
+    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+]
+image_bytes = [requests.get(uri).content for uri in uris]
+
+table.add(
+    pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
+)
+
+result = (
+    table.search("fluffy companion", vector_column_name="image_vectors")
+    .limit(1)
+    .to_pydantic(MediaItems)[0]
+)
+assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
+
+first_row = table.to_arrow().to_pylist()[0]
+assert len(first_row["image_vectors"]) > 1
+assert len(first_row["image_vectors"][0]) == func.ndims()