Compare commits


1 commit

Author            SHA1         Message   Date
ayush chaurasia   b766cbe0a9   init      2025-10-13 15:17:50 +05:30
6 changed files with 426 additions and 231 deletions

View File

@@ -80,7 +80,7 @@ embeddings = [
"pillow", "pillow",
"open-clip-torch", "open-clip-torch",
"cohere", "cohere",
"colpali-engine>=0.3.10", "colpali-engine>=0.3.12",
"huggingface_hub", "huggingface_hub",
"InstructorEmbedding", "InstructorEmbedding",
"google.generativeai", "google.generativeai",

View File

@@ -19,5 +19,5 @@ from .imagebind import ImageBindEmbeddings
 from .jinaai import JinaEmbeddings
 from .watsonx import WatsonxEmbeddings
 from .voyageai import VoyageAIEmbeddingFunction
-from .colpali import ColPaliEmbeddings
+from .colpali import MultimodalLateInteractionEmbeddings, ColPaliEmbeddings  # noqa: F401
 from .siglip import SigLipEmbeddings
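
The re-export keeps the old import path working. A minimal sketch of the intended backward compatibility (relying on the ColPaliEmbeddings alias defined at the bottom of colpali.py below): both names should resolve to the same class.

from lancedb.embeddings.colpali import (
    ColPaliEmbeddings,
    MultimodalLateInteractionEmbeddings,
)

assert ColPaliEmbeddings is MultimodalLateInteractionEmbeddings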

python/python/lancedb/embeddings/colpali.py Normal file → Executable file
View File

@@ -1,255 +1,347 @@
(The previous ColQwen2.5-only ColPaliEmbeddings implementation is replaced in full; the new file contents are shown below.)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Late-interaction embeddings powered by colpali-engine."""

from __future__ import annotations

import io
from typing import Any, Dict, List, Optional, Sequence

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, TEXT, is_flash_attn_2_available, weak_lru

_FAMILY_ALIASES = {
    "colsmol": {"colsmol", "colsmolvlm", "smol"},
    "colqwen2.5": {"colqwen2.5", "colqwen25", "colqwen-2.5"},
    "colqwen2": {"colqwen2", "colqwen-2"},
    "colpali": {"colpali", "paligemma"},
}

_FAMILY_CLASSES = {
    "colpali": ("ColPali", "ColPaliProcessor"),
    "colqwen2.5": ("ColQwen2_5", "ColQwen2_5_Processor"),
    "colqwen2": ("ColQwen2", "ColQwen2Processor"),
    "colsmol": ("ColIdefics3", "ColIdefics3Processor"),
}


def _torch() -> Any:
    return attempt_import_or_raise("torch", "torch")


def _torch_dtype(dtype: str) -> Any:
    torch = _torch()
    mapping = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
    }
    if dtype not in mapping:
        raise ValueError(
            "Unsupported dtype '{}'. Expected one of {}".format(
                dtype, ", ".join(sorted(mapping))
            )
        )
    return mapping[dtype]


def _load_pooler(use_pooler: bool) -> Optional[Any]:
    if not use_pooler:
        return None
    token_pooling = attempt_import_or_raise(
        "colpali_engine.compression.token_pooling", "colpali-engine"
    )
    pooler_cls = getattr(token_pooling, "HierarchicalTokenPooler", None)
    if pooler_cls is None:
        raise ImportError(
            "colpali_engine HierarchicalTokenPooler not available; update colpali-engine"
        )
    return pooler_cls()


def _move_to_device(batch: Any, device: Any) -> Any:
    if device is None:
        return batch
    torch = _torch()
    if isinstance(device, str):
        device_obj = torch.device(device)
    else:
        device_obj = device
    if isinstance(batch, dict):
        return {k: _move_to_device(v, device_obj) for k, v in batch.items()}
    if hasattr(batch, "to"):
        return batch.to(device_obj)
    return batch


@register("multimodal-late-interaction")
class MultimodalLateInteractionEmbeddings(EmbeddingFunction):
    """Late-interaction embeddings for ViDoRe models."""

    model_name: str = "vidore/colSmol-256M"
    model_family: Optional[str] = None
    device: str = "auto"
    dtype: str = "bfloat16"
    use_token_pooling: bool = True
    pool_factor: int = 2
    batch_size: int = 4
    quantization_config: Optional[Any] = None

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._family = self._resolve_family(self.model_name, self.model_family)
        self._vector_dim: Optional[int] = None

    @property
    def model(self) -> Any:
        """The cached model."""
        return self._get_models()[0]

    @property
    def processor(self) -> Any:
        """The cached processor."""
        return self._get_models()[1]

    @property
    def pooler(self) -> Optional[Any]:
        """The cached pooler."""
        return self._get_models()[2]

    @property
    def target_device(self) -> Optional[Any]:
        """The cached target device."""
        return self._get_models()[3]

    # ------------------------------------------------------------------
    # Family detection
    # ------------------------------------------------------------------
    @classmethod
    def _resolve_family(cls, model_name: str, explicit: Optional[str]) -> str:
        if explicit:
            family = explicit.lower()
            if family not in _FAMILY_CLASSES:
                raise ValueError(
                    "Unknown model_family '{}'. Expected one of {}".format(
                        explicit, ", ".join(sorted(_FAMILY_CLASSES))
                    )
                )
            return family
        lowered = model_name.lower()
        for family, aliases in _FAMILY_ALIASES.items():
            if any(alias in lowered for alias in aliases):
                return family
        return "colpali"

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------
    @weak_lru(maxsize=1)
    def _get_models(self) -> tuple[Any, Optional[Any], Optional[Any], Optional[Any]]:
        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali-engine")
        transformers = attempt_import_or_raise("transformers", "transformers")

        if (
            self.quantization_config is not None
            and not isinstance(
                self.quantization_config, transformers.BitsAndBytesConfig
            )
        ):
            raise ValueError(
                "quantization_config must be a transformers.BitsAndBytesConfig instance"
            )

        model_cls_name, processor_cls_name = _FAMILY_CLASSES[self._family]
        model_cls = getattr(colpali_engine.models, model_cls_name)
        processor_cls = getattr(colpali_engine.models, processor_cls_name)

        torch = _torch()
        device_map = self.device
        target_device: Optional[Any] = None
        if device_map == "auto":
            if torch.cuda.is_available():
                device_map = "cuda:0"
                target_device = torch.device("cuda:0")
            elif (
                getattr(torch.backends, "mps", None)
                and torch.backends.mps.is_available()
            ):
                device_map = "mps"
                target_device = torch.device("mps")
            else:
                device_map = "cpu"
                target_device = torch.device("cpu")
        else:
            try:
                target_device = torch.device(device_map)
            except (TypeError, ValueError):  # pragma: no cover - device map dicts
                target_device = None

        torch_dtype = _torch_dtype(self.dtype)
        if isinstance(device_map, str) and device_map == "cpu" and torch_dtype in {
            torch.bfloat16,
            torch.float16,
        }:
            torch_dtype = torch.float32

        load_kwargs: Dict[str, Any] = {
            "torch_dtype": torch_dtype,
            "device_map": device_map,
        }
        if self.quantization_config is not None:
            load_kwargs["quantization_config"] = self.quantization_config
        attn_impl = "flash_attention_2" if is_flash_attn_2_available() else None
        if attn_impl is not None:
            load_kwargs["attn_implementation"] = attn_impl

        model = model_cls.from_pretrained(self.model_name, **load_kwargs)
        if hasattr(model, "eval"):
            model = model.eval()
        processor = processor_cls.from_pretrained(self.model_name)
        pooler = _load_pooler(self.use_token_pooling)
        if target_device is None and hasattr(model, "device"):
            target_device = getattr(model, "device")
        return model, processor, pooler, target_device

    # ------------------------------------------------------------------
    # Encoding helpers
    # ------------------------------------------------------------------
    def _pool_tensor(self, tensor: Any) -> Any:
        if self.pooler is None:
            return tensor
        torch = _torch()
        assert isinstance(tensor, torch.Tensor)
        expanded = False
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
            expanded = True
        kwargs = {"pool_factor": self.pool_factor, "padding": True}
        tokenizer = getattr(
            getattr(self.processor, "tokenizer", None), "padding_side", None
        )
        if tokenizer is not None:
            kwargs["padding_side"] = tokenizer
        pooled = self.pooler.pool_embeddings(tensor, **kwargs)
        if expanded:
            pooled = pooled.squeeze(0)
        return pooled

    def _normalize_output(self, embeddings: Any) -> List[List[List[float]]]:
        torch = _torch()
        if hasattr(embeddings, "last_hidden_state"):
            return self._normalize_output(embeddings.last_hidden_state)
        if isinstance(embeddings, dict) and "last_hidden_state" in embeddings:
            return self._normalize_output(embeddings["last_hidden_state"])
        if isinstance(embeddings, torch.Tensor):
            pooled = self._pool_tensor(embeddings).detach().cpu()
            if pooled.ndim == 2:
                pooled = pooled.unsqueeze(0)
            target = torch.float64 if self.dtype == "float64" else torch.float32
            return pooled.to(target).numpy().tolist()
        if isinstance(embeddings, (list, tuple)):
            results: List[List[List[float]]] = []
            for item in embeddings:
                results.extend(self._normalize_output(item))
            return results
        raise TypeError(f"Unsupported embedding type {type(embeddings)}")

    # ------------------------------------------------------------------
    # Text encoding
    # ------------------------------------------------------------------
    def _encode_text(self, batch: Sequence[str]) -> List[List[List[float]]]:
        if not self.processor or not hasattr(self.processor, "process_queries"):
            raise RuntimeError("Processor for text queries is not available for this model")
        payload = self.processor.process_queries(batch)
        payload = _move_to_device(payload, self.target_device)
        torch = _torch()
        with torch.no_grad():
            outputs = self.model(**payload)
        return self._normalize_output(outputs)

    # ------------------------------------------------------------------
    # Image encoding
    # ------------------------------------------------------------------
    def _prepare_images(self, images: IMAGES) -> List[Any]:
        PIL = attempt_import_or_raise("PIL", "pillow")
        requests = attempt_import_or_raise("requests", "requests")
        prepared: List[Any] = []
        for image in self.sanitize_input(images):
            if isinstance(image, str) and image.startswith(("http://", "https://")):
                response = requests.get(image, timeout=10)
                response.raise_for_status()
                prepared.append(PIL.Image.open(io.BytesIO(response.content)))
            elif isinstance(image, str):
                with PIL.Image.open(image) as img:
                    prepared.append(img.copy())
            elif isinstance(image, bytes):
                prepared.append(PIL.Image.open(io.BytesIO(image)))
            else:
                prepared.append(image)
        return prepared

    def _encode_images(self, images: Sequence[Any]) -> List[List[List[float]]]:
        if not self.processor or not hasattr(self.processor, "process_images"):
            raise RuntimeError("Processor for images is not available for this model")
        payload = self.processor.process_images(images)
        payload = _move_to_device(payload, self.target_device)
        torch = _torch()
        with torch.no_grad():
            outputs = self.model(**payload)
        return self._normalize_output(outputs)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def _batched(self, values: Sequence[Any], encoder) -> List[List[List[float]]]:
        results: List[List[List[float]]] = []
        for start in range(0, len(values), self.batch_size):
            chunk = values[start : start + self.batch_size]
            results.extend(encoder(chunk))
        return results

    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
        text = self.sanitize_input(text)
        if len(text) == 0:
            return []
        return self._batched(text, self._encode_text)

    def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
        prepared = self._prepare_images(images)
        if len(prepared) == 0:
            return []
        return self._batched(prepared, self._encode_images)

    def compute_query_embeddings(
        self, query: str, *args: Any, **kwargs: Any
    ) -> List[List[List[float]]]:
        if not isinstance(query, str):
            raise ValueError("Late interaction queries must be text")
        return self.generate_text_embeddings([query])

    def compute_source_embeddings(
        self, images: IMAGES, *args: Any, **kwargs: Any
    ) -> List[List[List[float]]]:
        return self.generate_image_embeddings(images)

    def ndims(self) -> int:
        if self._vector_dim is None:
            probe = self.generate_text_embeddings(["probe"])
            if not probe or not probe[0]:
                raise RuntimeError("Failed to determine embedding dimension")
            self._vector_dim = len(probe[0][0])
        return self._vector_dim


# Backwards compatibility: keep the historical "colpali" key
register("colpali")(MultimodalLateInteractionEmbeddings)

# Legacy class name kept for backwards compatibility in imports
ColPaliEmbeddings = MultimodalLateInteractionEmbeddings
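
For reference, a sketch of the intended end-to-end usage of the new registry key, mirroring the updated tests further down; the model name, table name, and database path here are illustrative, not part of the change.

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, MultiVector

# Any supported ViDoRe checkpoint should work; family detection picks the
# matching colpali-engine model/processor classes from the model name.
func = get_registry().get("multimodal-late-interaction").create(
    model_name="vidore/colSmol-256M",
    device="auto",
    batch_size=1,
)


class Docs(LanceModel):
    text: str
    image_uri: str = func.SourceField()
    # One multivector per image; each row stores several func.ndims()-dim vectors.
    image_vectors: MultiVector(func.ndims()) = func.VectorField()


db = lancedb.connect("/tmp/late_interaction_demo")  # illustrative path
table = db.create_table("docs", schema=Docs)
# table.add(...) with rows containing image URIs, then query the multivector column:
# table.search("a fluffy companion", vector_column_name="image_vectors").limit(3)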

View File

@@ -17,6 +17,7 @@ from lancedb.embeddings import (
     EmbeddingFunctionRegistry,
 )
 from lancedb.embeddings.base import TextEmbeddingFunction
+from lancedb.embeddings.colpali import MultimodalLateInteractionEmbeddings
 from lancedb.embeddings.registry import get_registry, register
 from lancedb.embeddings.utils import retry
 from lancedb.pydantic import LanceModel, Vector
@@ -515,3 +516,16 @@ def test_openai_propagates_api_key(monkeypatch):
     query = "greetings"
     actual = table.search(query).limit(1).to_pydantic(Words)[0]
     assert len(actual.text) > 0
+
+
+def test_multimodal_late_interaction_family_detection():
+    resolver = MultimodalLateInteractionEmbeddings._resolve_family
+    assert resolver("vidore/colSmol-256M", None) == "colsmol"
+    assert resolver("vidore/colqwen2.5-v0.2", None) == "colqwen2.5"
+    assert resolver("vidore/colqwen2-v1.0", None) == "colqwen2"
+    assert resolver("vidore/colpali-v1.3", None) == "colpali"
+    assert resolver("custom/model", None) == "colpali"
+    assert resolver("any/model", "colqwen2") == "colqwen2"
+    with pytest.raises(ValueError):
+        resolver("any/model", "unknown")

View File

@@ -4,6 +4,7 @@
 import importlib
 import io
 import os
+
 import lancedb
 import numpy as np
 import pandas as pd
@@ -33,6 +34,24 @@ try:
 except Exception:
     _imagebind = None

+try:
+    import torch
+except ImportError:
+    torch = None
+
+HAS_ACCEL = bool(
+    torch
+    and (
+        torch.cuda.is_available()
+        or getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+    )
+)
+
+RUN_HEAVY_VIDORE = os.getenv("LANCEDB_TEST_FULL_LATE_INTERACTION") in {"1", "true", "yes"}
+HEAVY_SKIP = pytest.mark.skipif(
+    not (RUN_HEAVY_VIDORE and HAS_ACCEL),
+    reason="Set LANCEDB_TEST_FULL_LATE_INTERACTION=1 and run on GPU to exercise large vidore checkpoints",
+)
+
 @pytest.mark.slow
 @pytest.mark.parametrize(
@@ -597,21 +616,44 @@ def test_voyageai_multimodal_embedding_text_function():
     importlib.util.find_spec("colpali_engine") is None,
     reason="colpali_engine not installed",
 )
-def test_colpali(tmp_path):
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        # pytest.param("vidore/colSmol-256M", id="colSmol"),
+        pytest.param(
+            "vidore/colqwen2-v1.0",
+            id="colQwen2",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colqwen2.5-v0.2",
+            id="colQwen2.5",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colpali-v1.3-merged",
+            id="colPali",
+            # marks=HEAVY_SKIP,
+        ),
+    ],
+)
+def test_multimodal_late_interaction_models(tmp_path, model_name):
     import requests
-    from lancedb.pydantic import LanceModel
+    from lancedb.pydantic import LanceModel, Vector

     db = lancedb.connect(tmp_path)
     registry = get_registry()
-    func = registry.get("colpali").create()
+    func = registry.get("multimodal-late-interaction").create(
+        model_name=model_name,
+        device="auto",
+        batch_size=1,
+    )

     class MediaItems(LanceModel):
         text: str
         image_uri: str = func.SourceField()
         image_bytes: bytes = func.SourceField()
-        image_vectors: MultiVector(func.ndims()) = (
-            func.VectorField()
-        )  # Multivector image embeddings
+        image_vectors: MultiVector(func.ndims()) = func.VectorField()

     table = db.create_table("media", schema=MediaItems)
@@ -619,41 +661,30 @@ def test_colpali(tmp_path):
         "a cute cat playing with yarn",
         "a puppy in a flower field",
         "a red sports car on the highway",
-        "a vintage bicycle leaning against a wall",
-        "a plate of delicious pasta",
-        "fresh fruit salad in a bowl",
     ]
     uris = [
         "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
         "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
-        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
         "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
-        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
-        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
     ]
-    # Get images as bytes
     image_bytes = [requests.get(uri).content for uri in uris]

     table.add(
         pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
     )

-    # Test text-to-image search
-    image_results = (
+    result = (
         table.search("fluffy companion", vector_column_name="image_vectors")
         .limit(1)
         .to_pydantic(MediaItems)[0]
     )
-    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
+    assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))

-    # Verify multivector dimensions
     first_row = table.to_arrow().to_pylist()[0]
-    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
-    assert len(first_row["image_vectors"][0]) == func.ndims(), (
-        "Vector dimension mismatch"
-    )
+    assert len(first_row["image_vectors"]) > 1
+    assert len(first_row["image_vectors"][0]) == func.ndims()


 @pytest.mark.slow

python/test.py Normal file
View File

@@ -0,0 +1,58 @@
import requests
from lancedb.pydantic import LanceModel, Vector
import importlib
import io
import os

import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector, MultiVector

db = lancedb.connect("~/.db")
registry = get_registry()
func = registry.get("multimodal-late-interaction").create(
    model_name="vidore/colQwen2.5-v0.2",
    device="auto",
    batch_size=1,
)


class MediaItems(LanceModel):
    text: str
    image_uri: str = func.SourceField()
    image_bytes: bytes = func.SourceField()
    image_vectors: MultiVector(func.ndims()) = func.VectorField()


table = db.create_table("media", schema=MediaItems, mode="overwrite")

texts = [
    "a cute cat playing with yarn",
    "a puppy in a flower field",
    "a red sports car on the highway",
]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
]
image_bytes = [requests.get(uri).content for uri in uris]

table.add(
    pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)

result = (
    table.search("fluffy companion", vector_column_name="image_vectors")
    .limit(1)
    .to_pydantic(MediaItems)[0]
)
assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))

first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1
assert len(first_row["image_vectors"][0]) == func.ndims()