From b766cbe0a9d723638fd1edc58d188334b516cc3b Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 13 Oct 2025 15:17:50 +0530 Subject: [PATCH] init --- python/pyproject.toml | 2 +- python/python/lancedb/embeddings/__init__.py | 2 +- python/python/lancedb/embeddings/colpali.py | 508 +++++++++++-------- python/python/tests/test_embeddings.py | 14 + python/python/tests/test_embeddings_slow.py | 73 ++- python/test.py | 58 +++ 6 files changed, 426 insertions(+), 231 deletions(-) mode change 100644 => 100755 python/python/lancedb/embeddings/colpali.py create mode 100644 python/test.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 118b23dc..791e445f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -80,7 +80,7 @@ embeddings = [ "pillow", "open-clip-torch", "cohere", - "colpali-engine>=0.3.10", + "colpali-engine>=0.3.12", "huggingface_hub", "InstructorEmbedding", "google.generativeai", diff --git a/python/python/lancedb/embeddings/__init__.py b/python/python/lancedb/embeddings/__init__.py index f70aa57b..368134eb 100644 --- a/python/python/lancedb/embeddings/__init__.py +++ b/python/python/lancedb/embeddings/__init__.py @@ -19,5 +19,5 @@ from .imagebind import ImageBindEmbeddings from .jinaai import JinaEmbeddings from .watsonx import WatsonxEmbeddings from .voyageai import VoyageAIEmbeddingFunction -from .colpali import ColPaliEmbeddings +from .colpali import MultimodalLateInteractionEmbeddings, ColPaliEmbeddings # noqa: F401 from .siglip import SigLipEmbeddings diff --git a/python/python/lancedb/embeddings/colpali.py b/python/python/lancedb/embeddings/colpali.py old mode 100644 new mode 100755 index 150eaa52..b2c92ca4 --- a/python/python/lancedb/embeddings/colpali.py +++ b/python/python/lancedb/embeddings/colpali.py @@ -1,255 +1,347 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +"""Late-interaction embeddings powered by colpali-engine.""" + +from __future__ import annotations -from functools import lru_cache -from typing import List, Union, Optional, Any -import numpy as np import io +from typing import Any, Dict, List, Optional, Sequence from ..util import attempt_import_or_raise from .base import EmbeddingFunction from .registry import register -from .utils import TEXT, IMAGES, is_flash_attn_2_available +from .utils import IMAGES, TEXT, is_flash_attn_2_available, weak_lru -@register("colpali") -class ColPaliEmbeddings(EmbeddingFunction): - """ - An embedding function that uses the ColPali engine for - multimodal multi-vector embeddings. +_FAMILY_ALIASES = { + "colsmol": {"colsmol", "colsmolvlm", "smol"}, + "colqwen2.5": {"colqwen2.5", "colqwen25", "colqwen-2.5"}, + "colqwen2": {"colqwen2", "colqwen-2"}, + "colpali": {"colpali", "paligemma"}, +} - This embedding function supports ColQwen2.5 models, producing multivector outputs - for both text and image inputs. The output embeddings are lists of vectors, each - vector being 128-dimensional by default, represented as List[List[float]]. +_FAMILY_CLASSES = { + "colpali": ("ColPali", "ColPaliProcessor"), + "colqwen2.5": ("ColQwen2_5", "ColQwen2_5_Processor"), + "colqwen2": ("ColQwen2", "ColQwen2Processor"), + "colsmol": ("ColIdefics3", "ColIdefics3Processor"), +} - Parameters - ---------- - model_name : str - The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0") - device : str - The device for inference (default "cuda:0"). - dtype : str - Data type for model weights (default "bfloat16"). 
- use_token_pooling : bool - Whether to use token pooling to reduce embedding size (default True). - pool_factor : int - Factor to reduce sequence length if token pooling is enabled (default 2). - quantization_config : Optional[BitsAndBytesConfig] - Quantization configuration for the model. (default None, bitsandbytes needed) - batch_size : int - Batch size for processing inputs (default 2). - """ - model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" +def _torch() -> Any: + return attempt_import_or_raise("torch", "torch") + + +def _torch_dtype(dtype: str) -> Any: + torch = _torch() + mapping = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + } + if dtype not in mapping: + raise ValueError( + "Unsupported dtype '{}'. Expected one of {}".format( + dtype, ", ".join(sorted(mapping)) + ) + ) + return mapping[dtype] + + +def _load_pooler(use_pooler: bool) -> Optional[Any]: + if not use_pooler: + return None + token_pooling = attempt_import_or_raise( + "colpali_engine.compression.token_pooling", "colpali-engine" + ) + pooler_cls = getattr(token_pooling, "HierarchicalTokenPooler", None) + if pooler_cls is None: + raise ImportError( + "colpali_engine HierarchicalTokenPooler not available; update colpali-engine" + ) + return pooler_cls() + + +def _move_to_device(batch: Any, device: Any) -> Any: + if device is None: + return batch + torch = _torch() + if isinstance(device, str): + device_obj = torch.device(device) + else: + device_obj = device + if isinstance(batch, dict): + return {k: _move_to_device(v, device_obj) for k, v in batch.items()} + if hasattr(batch, "to"): + return batch.to(device_obj) + return batch + + +@register("multimodal-late-interaction") +class MultimodalLateInteractionEmbeddings(EmbeddingFunction): + """Late-interaction embeddings for ViDoRe models.""" + + model_name: str = "vidore/colSmol-256M" + model_family: Optional[str] = None device: str = "auto" dtype: str = "bfloat16" use_token_pooling: bool = True pool_factor: int = 2 + batch_size: int = 4 quantization_config: Optional[Any] = None - batch_size: int = 2 - _model = None - _processor = None - _token_pooler = None - _vector_dim = None - - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - ( - self._model, - self._processor, - self._token_pooler, - ) = self._load_model( - self.model_name, - self.dtype, - self.device, - self.use_token_pooling, - self.quantization_config, - ) + self._family = self._resolve_family(self.model_name, self.model_family) + self._vector_dim: Optional[int] = None - @staticmethod - @lru_cache(maxsize=1) - def _load_model( - model_name: str, - dtype: str, - device: str, - use_token_pooling: bool, - quantization_config: Optional[Any], - ): - """ - Initialize and cache the ColPali model, processor, and token pooler. 
- """ - torch = attempt_import_or_raise("torch", "torch") + @property + def model(self) -> Any: + """The cached model.""" + return self._get_models()[0] + + @property + def processor(self) -> Any: + """The cached processor.""" + return self._get_models()[1] + + @property + def pooler(self) -> Optional[Any]: + """The cached pooler.""" + return self._get_models()[2] + + @property + def target_device(self) -> Optional[Any]: + """The cached target device.""" + return self._get_models()[3] + + # ------------------------------------------------------------------ + # Family detection + # ------------------------------------------------------------------ + @classmethod + def _resolve_family(cls, model_name: str, explicit: Optional[str]) -> str: + if explicit: + family = explicit.lower() + if family not in _FAMILY_CLASSES: + raise ValueError( + "Unknown model_family '{}'. Expected one of {}".format( + explicit, ", ".join(sorted(_FAMILY_CLASSES)) + ) + ) + return family + + lowered = model_name.lower() + for family, aliases in _FAMILY_ALIASES.items(): + if any(alias in lowered for alias in aliases): + return family + return "colpali" + + # ------------------------------------------------------------------ + # Model loading + # ------------------------------------------------------------------ + @weak_lru(maxsize=1) + def _get_models(self) -> tuple[Any, Optional[Any], Optional[Any], Optional[Any]]: + colpali_engine = attempt_import_or_raise("colpali_engine", "colpali-engine") transformers = attempt_import_or_raise("transformers", "transformers") - colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine") - from colpali_engine.compression.token_pooling import HierarchicalTokenPooler - if quantization_config is not None: - if not isinstance(quantization_config, transformers.BitsAndBytesConfig): - raise ValueError("quantization_config must be a BitsAndBytesConfig") + if ( + self.quantization_config is not None + and not isinstance( + self.quantization_config, transformers.BitsAndBytesConfig + ) + ): + raise ValueError( + "quantization_config must be a transformers.BitsAndBytesConfig instance" + ) - if dtype == "bfloat16": - torch_dtype = torch.bfloat16 - elif dtype == "float16": - torch_dtype = torch.float16 - elif dtype == "float64": - torch_dtype = torch.float64 + model_cls_name, processor_cls_name = _FAMILY_CLASSES[self._family] + model_cls = getattr(colpali_engine.models, model_cls_name) + processor_cls = getattr(colpali_engine.models, processor_cls_name) + + torch = _torch() + device_map = self.device + target_device: Optional[Any] = None + if device_map == "auto": + if torch.cuda.is_available(): + device_map = "cuda:0" + target_device = torch.device("cuda:0") + elif ( + getattr(torch.backends, "mps", None) + and torch.backends.mps.is_available() + ): + device_map = "mps" + target_device = torch.device("mps") + else: + device_map = "cpu" + target_device = torch.device("cpu") else: + try: + target_device = torch.device(device_map) + except (TypeError, ValueError): # pragma: no cover - device map dicts + target_device = None + + torch_dtype = _torch_dtype(self.dtype) + if isinstance(device_map, str) and device_map == "cpu" and torch_dtype in { + torch.bfloat16, + torch.float16, + }: torch_dtype = torch.float32 - model = colpali_engine.models.ColQwen2_5.from_pretrained( - model_name, - torch_dtype=torch_dtype, - device_map=device, - quantization_config=quantization_config - if quantization_config is not None - else None, - attn_implementation="flash_attention_2" - if 
is_flash_attn_2_available() - else None, - ).eval() - processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained( - model_name + load_kwargs: Dict[str, Any] = { + "torch_dtype": torch_dtype, + "device_map": device_map, + } + if self.quantization_config is not None: + load_kwargs["quantization_config"] = self.quantization_config + attn_impl = "flash_attention_2" if is_flash_attn_2_available() else None + if attn_impl is not None: + load_kwargs["attn_implementation"] = attn_impl + + model = model_cls.from_pretrained(self.model_name, **load_kwargs) + if hasattr(model, "eval"): + model = model.eval() + + processor = processor_cls.from_pretrained(self.model_name) + pooler = _load_pooler(self.use_token_pooling) + if target_device is None and hasattr(model, "device"): + target_device = getattr(model, "device") + + return model, processor, pooler, target_device + + # ------------------------------------------------------------------ + # Encoding helpers + # ------------------------------------------------------------------ + def _pool_tensor(self, tensor: Any) -> Any: + if self.pooler is None: + return tensor + torch = _torch() + assert isinstance(tensor, torch.Tensor) + expanded = False + if tensor.ndim == 2: + tensor = tensor.unsqueeze(0) + expanded = True + kwargs = {"pool_factor": self.pool_factor, "padding": True} + tokenizer = getattr( + getattr(self.processor, "tokenizer", None), "padding_side", None ) - token_pooler = HierarchicalTokenPooler() if use_token_pooling else None - return model, processor, token_pooler - - def ndims(self): - """ - Return the dimension of a vector in the multivector output (e.g., 128). - """ - torch = attempt_import_or_raise("torch", "torch") - if self._vector_dim is None: - dummy_query = "test" - batch_queries = self._processor.process_queries([dummy_query]).to( - self._model.device - ) - with torch.no_grad(): - query_embeddings = self._model(**batch_queries) - - if self.use_token_pooling and self._token_pooler is not None: - query_embeddings = self._token_pooler.pool_embeddings( - query_embeddings, - pool_factor=self.pool_factor, - padding=True, - padding_side=self._processor.tokenizer.padding_side, - ) - - self._vector_dim = query_embeddings[0].shape[-1] - return self._vector_dim - - def _process_embeddings(self, embeddings): - """ - Format model embeddings into List[List[float]]. - Use token pooling if enabled. 
- """ - torch = attempt_import_or_raise("torch", "torch") - if self.use_token_pooling and self._token_pooler is not None: - embeddings = self._token_pooler.pool_embeddings( - embeddings, - pool_factor=self.pool_factor, - padding=True, - padding_side=self._processor.tokenizer.padding_side, - ) + if tokenizer is not None: + kwargs["padding_side"] = tokenizer + pooled = self.pooler.pool_embeddings(tensor, **kwargs) + if expanded: + pooled = pooled.squeeze(0) + return pooled + def _normalize_output(self, embeddings: Any) -> List[List[List[float]]]: + torch = _torch() + if hasattr(embeddings, "last_hidden_state"): + return self._normalize_output(embeddings.last_hidden_state) + if isinstance(embeddings, dict) and "last_hidden_state" in embeddings: + return self._normalize_output(embeddings["last_hidden_state"]) if isinstance(embeddings, torch.Tensor): - tensors = embeddings.detach().cpu() - if tensors.dtype == torch.bfloat16: - tensors = tensors.to(torch.float32) - return ( - tensors.numpy() - .astype(np.float64 if self.dtype == "float64" else np.float32) - .tolist() - ) - return [] + pooled = self._pool_tensor(embeddings).detach().cpu() + if pooled.ndim == 2: + pooled = pooled.unsqueeze(0) + target = torch.float64 if self.dtype == "float64" else torch.float32 + return pooled.to(target).numpy().tolist() + if isinstance(embeddings, (list, tuple)): + results: List[List[List[float]]] = [] + for item in embeddings: + results.extend(self._normalize_output(item)) + return results + raise TypeError(f"Unsupported embedding type {type(embeddings)}") - def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]: - """ - Generate embeddings for text input. - """ - torch = attempt_import_or_raise("torch", "torch") - text = self.sanitize_input(text) - all_embeddings = [] + # ------------------------------------------------------------------ + # Text encoding + # ------------------------------------------------------------------ + def _encode_text(self, batch: Sequence[str]) -> List[List[List[float]]]: + if not self.processor or not hasattr(self.processor, "process_queries"): + raise RuntimeError("Processor for text queries is not available for this model") + payload = self.processor.process_queries(batch) + payload = _move_to_device(payload, self.target_device) + torch = _torch() + with torch.no_grad(): + outputs = self.model(**payload) + return self._normalize_output(outputs) - for i in range(0, len(text), self.batch_size): - batch_text = text[i : i + self.batch_size] - batch_queries = self._processor.process_queries(batch_text).to( - self._model.device - ) - with torch.no_grad(): - query_embeddings = self._model(**batch_queries) - all_embeddings.extend(self._process_embeddings(query_embeddings)) - return all_embeddings - - def _prepare_images(self, images: IMAGES) -> List: - """ - Convert image inputs to PIL Images. 
- """ + # ------------------------------------------------------------------ + # Image encoding + # ------------------------------------------------------------------ + def _prepare_images(self, images: IMAGES) -> List[Any]: PIL = attempt_import_or_raise("PIL", "pillow") requests = attempt_import_or_raise("requests", "requests") - images = self.sanitize_input(images) - pil_images = [] - try: - for image in images: - if isinstance(image, str): - if image.startswith(("http://", "https://")): - response = requests.get(image, timeout=10) - response.raise_for_status() - pil_images.append(PIL.Image.open(io.BytesIO(response.content))) - else: - with PIL.Image.open(image) as im: - pil_images.append(im.copy()) - elif isinstance(image, bytes): - pil_images.append(PIL.Image.open(io.BytesIO(image))) - else: - # Assume it's a PIL Image; will raise if invalid - pil_images.append(image) - except Exception as e: - raise ValueError(f"Failed to process image: {e}") + prepared: List[Any] = [] + for image in self.sanitize_input(images): + if isinstance(image, str) and image.startswith(("http://", "https://")): + response = requests.get(image, timeout=10) + response.raise_for_status() + prepared.append(PIL.Image.open(io.BytesIO(response.content))) + elif isinstance(image, str): + with PIL.Image.open(image) as img: + prepared.append(img.copy()) + elif isinstance(image, bytes): + prepared.append(PIL.Image.open(io.BytesIO(image))) + else: + prepared.append(image) + return prepared - return pil_images + def _encode_images(self, images: Sequence[Any]) -> List[List[List[float]]]: + if not self.processor or not hasattr(self.processor, "process_images"): + raise RuntimeError("Processor for images is not available for this model") + payload = self.processor.process_images(images) + payload = _move_to_device(payload, self.target_device) + torch = _torch() + with torch.no_grad(): + outputs = self.model(**payload) + return self._normalize_output(outputs) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def _batched(self, values: Sequence[Any], encoder) -> List[List[List[float]]]: + results: List[List[List[float]]] = [] + for start in range(0, len(values), self.batch_size): + chunk = values[start : start + self.batch_size] + results.extend(encoder(chunk)) + return results + + def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]: + text = self.sanitize_input(text) + if len(text) == 0: + return [] + return self._batched(text, self._encode_text) def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]: - """ - Generate embeddings for a batch of images. 
- """ - torch = attempt_import_or_raise("torch", "torch") - pil_images = self._prepare_images(images) - all_embeddings = [] - - for i in range(0, len(pil_images), self.batch_size): - batch_images = pil_images[i : i + self.batch_size] - batch_images = self._processor.process_images(batch_images).to( - self._model.device - ) - with torch.no_grad(): - image_embeddings = self._model(**batch_images) - all_embeddings.extend(self._process_embeddings(image_embeddings)) - return all_embeddings + prepared = self._prepare_images(images) + if len(prepared) == 0: + return [] + return self._batched(prepared, self._encode_images) def compute_query_embeddings( - self, query: Union[str, IMAGES], *args, **kwargs + self, query: str, *args: Any, **kwargs: Any ) -> List[List[List[float]]]: - """ - Compute embeddings for a single user query (text only). - """ if not isinstance(query, str): - raise ValueError( - "Query must be a string, image to image search is not supported" - ) + raise ValueError("Late interaction queries must be text") return self.generate_text_embeddings([query]) def compute_source_embeddings( - self, images: IMAGES, *args, **kwargs + self, images: IMAGES, *args: Any, **kwargs: Any ) -> List[List[List[float]]]: - """ - Compute embeddings for a batch of source images. - - Parameters - ---------- - images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray] - Batch of images (paths, URLs, bytes, or PIL Images). - """ - images = self.sanitize_input(images) return self.generate_image_embeddings(images) + + def ndims(self) -> int: + if self._vector_dim is None: + probe = self.generate_text_embeddings(["probe"]) + if not probe or not probe[0]: + raise RuntimeError("Failed to determine embedding dimension") + self._vector_dim = len(probe[0][0]) + return self._vector_dim + + +# Backwards compatibility: keep the historical "colpali" key +register("colpali")(MultimodalLateInteractionEmbeddings) + + +# Legacy class name kept for backwards compatibility in imports +ColPaliEmbeddings = MultimodalLateInteractionEmbeddings diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 2f01cf3b..373b48df 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -17,6 +17,7 @@ from lancedb.embeddings import ( EmbeddingFunctionRegistry, ) from lancedb.embeddings.base import TextEmbeddingFunction +from lancedb.embeddings.colpali import MultimodalLateInteractionEmbeddings from lancedb.embeddings.registry import get_registry, register from lancedb.embeddings.utils import retry from lancedb.pydantic import LanceModel, Vector @@ -515,3 +516,16 @@ def test_openai_propagates_api_key(monkeypatch): query = "greetings" actual = table.search(query).limit(1).to_pydantic(Words)[0] assert len(actual.text) > 0 + + +def test_multimodal_late_interaction_family_detection(): + resolver = MultimodalLateInteractionEmbeddings._resolve_family + + assert resolver("vidore/colSmol-256M", None) == "colsmol" + assert resolver("vidore/colqwen2.5-v0.2", None) == "colqwen2.5" + assert resolver("vidore/colqwen2-v1.0", None) == "colqwen2" + assert resolver("vidore/colpali-v1.3", None) == "colpali" + assert resolver("custom/model", None) == "colpali" + assert resolver("any/model", "colqwen2") == "colqwen2" + with pytest.raises(ValueError): + resolver("any/model", "unknown") diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 50bf76e2..2035e2ee 100644 --- 
a/python/python/tests/test_embeddings_slow.py
+++ b/python/python/tests/test_embeddings_slow.py
@@ -4,6 +4,7 @@
 import importlib
 import io
 import os
+
 import lancedb
 import numpy as np
 import pandas as pd
@@ -33,6 +34,24 @@ try:
 except Exception:
     _imagebind = None
 
+try:
+    import torch
+except ImportError:
+    torch = None
+
+HAS_ACCEL = bool(
+    torch
+    and (
+        torch.cuda.is_available()
+        or getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+    )
+)
+RUN_HEAVY_VIDORE = os.getenv("LANCEDB_TEST_FULL_LATE_INTERACTION") in {"1", "true", "yes"}
+HEAVY_SKIP = pytest.mark.skipif(
+    not (RUN_HEAVY_VIDORE and HAS_ACCEL),
+    reason="Set LANCEDB_TEST_FULL_LATE_INTERACTION=1 and run on GPU to exercise large vidore checkpoints",
+)
+
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
@@ -597,21 +616,44 @@ def test_voyageai_multimodal_embedding_text_function():
     importlib.util.find_spec("colpali_engine") is None,
     reason="colpali_engine not installed",
 )
-def test_colpali(tmp_path):
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        pytest.param("vidore/colSmol-256M", id="colSmol"),
+        pytest.param(
+            "vidore/colqwen2-v1.0",
+            id="colQwen2",
+            marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colqwen2.5-v0.2",
+            id="colQwen2.5",
+            marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colpali-v1.3-merged",
+            id="colPali",
+            marks=HEAVY_SKIP,
+        ),
+    ],
+)
+def test_multimodal_late_interaction_models(tmp_path, model_name):
     import requests
-    from lancedb.pydantic import LanceModel
+    from lancedb.pydantic import LanceModel, Vector
 
     db = lancedb.connect(tmp_path)
     registry = get_registry()
-    func = registry.get("colpali").create()
+    func = registry.get("multimodal-late-interaction").create(
+        model_name=model_name,
+        device="auto",
+        batch_size=1,
+    )
 
     class MediaItems(LanceModel):
         text: str
         image_uri: str = func.SourceField()
         image_bytes: bytes = func.SourceField()
-        image_vectors: MultiVector(func.ndims()) = (
-            func.VectorField()
-        )  # Multivector image embeddings
+        image_vectors: MultiVector(func.ndims()) = func.VectorField()
 
     table = db.create_table("media", schema=MediaItems)
 
@@ -619,41 +661,30 @@ def test_colpali(tmp_path):
         "a cute cat playing with yarn",
         "a puppy in a flower field",
         "a red sports car on the highway",
-        "a vintage bicycle leaning against a wall",
-        "a plate of delicious pasta",
-        "fresh fruit salad in a bowl",
     ]
     uris = [
         "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
         "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
-        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
         "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
-        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
-        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
-    # Get images as bytes
     image_bytes = [requests.get(uri).content for uri in uris]
 
     table.add(
         pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
     )
 
-    # Test text-to-image search
-    image_results = (
+    result = (
         table.search("fluffy companion", vector_column_name="image_vectors")
         .limit(1)
         .to_pydantic(MediaItems)[0]
     )
-    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
+    assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
 
-    # Verify multivector dimensions
     first_row = table.to_arrow().to_pylist()[0]
-    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
-    assert len(first_row["image_vectors"][0]) == func.ndims(), (
-        "Vector dimension mismatch"
-    )
+    assert 
len(first_row["image_vectors"]) > 1 + assert len(first_row["image_vectors"][0]) == func.ndims() @pytest.mark.slow diff --git a/python/test.py b/python/test.py new file mode 100644 index 00000000..fb9d7bf5 --- /dev/null +++ b/python/test.py @@ -0,0 +1,58 @@ +import requests +from lancedb.pydantic import LanceModel, Vector +import importlib +import io +import os + +import lancedb +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +from lancedb.embeddings import get_registry +from lancedb.pydantic import LanceModel, Vector, MultiVector + +db = lancedb.connect("~/.db") +registry = get_registry() +func = registry.get("multimodal-late-interaction").create( + model_name="vidore/colQwen2.5-v0.2", + device="auto", + batch_size=1, +) + +class MediaItems(LanceModel): + text: str + image_uri: str = func.SourceField() + image_bytes: bytes = func.SourceField() + image_vectors: MultiVector(func.ndims()) = func.VectorField() + +table = db.create_table("media", schema=MediaItems, mode="overwrite") + +texts = [ + "a cute cat playing with yarn", + "a puppy in a flower field", + "a red sports car on the highway", +] + +uris = [ + "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg", + "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg", + "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg", +] + +image_bytes = [requests.get(uri).content for uri in uris] + +table.add( + pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes}) +) + +result = ( + table.search("fluffy companion", vector_column_name="image_vectors") + .limit(1) + .to_pydantic(MediaItems)[0] +) +assert any(keyword in result.text.lower() for keyword in ("cat", "puppy")) + +first_row = table.to_arrow().to_pylist()[0] +assert len(first_row["image_vectors"]) > 1 +assert len(first_row["image_vectors"][0]) == func.ndims() \ No newline at end of file
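A minimal usage sketch, assuming colpali-engine is installed and the vidore/colSmol-256M checkpoint can be downloaded; the keyword arguments mirror the fields declared on MultimodalLateInteractionEmbeddings, and the legacy "colpali" registry key resolves to the same class as "multimodal-late-interaction":

    from lancedb.embeddings import get_registry

    # "multimodal-late-interaction" is the new registry key; "colpali" is kept as a
    # backwards-compatible alias for the same embedding function.
    func = get_registry().get("multimodal-late-interaction").create(
        model_name="vidore/colSmol-256M",  # light default declared on the class
        device="auto",                     # resolves to cuda, mps, or cpu
        use_token_pooling=True,
        pool_factor=2,
        batch_size=4,
    )

    # A text query produces one multivector: a list of token-level vectors,
    # each of length func.ndims().
    query_vectors = func.compute_query_embeddings("a puppy in a flower field")[0]
    assert len(query_vectors[0]) == func.ndims()

Setting use_token_pooling=False skips HierarchicalTokenPooler entirely and keeps the full token sequence, at the cost of more vectors per row.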