Mirror of https://github.com/lancedb/lancedb.git

Commit: init
@@ -80,7 +80,7 @@ embeddings = [
     "pillow",
     "open-clip-torch",
     "cohere",
-    "colpali-engine>=0.3.10",
+    "colpali-engine>=0.3.12",
     "huggingface_hub",
     "InstructorEmbedding",
     "google.generativeai",
@@ -19,5 +19,5 @@ from .imagebind import ImageBindEmbeddings
 from .jinaai import JinaEmbeddings
 from .watsonx import WatsonxEmbeddings
 from .voyageai import VoyageAIEmbeddingFunction
-from .colpali import ColPaliEmbeddings
+from .colpali import MultimodalLateInteractionEmbeddings, ColPaliEmbeddings  # noqa: F401
 from .siglip import SigLipEmbeddings
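Both names import cleanly after this change; a quick check of the intended compatibility (the alias is defined at the bottom of colpali.py, below):

from lancedb.embeddings import ColPaliEmbeddings, MultimodalLateInteractionEmbeddings

# The legacy name is bound to the same class object, so existing code keeps working:
assert ColPaliEmbeddings is MultimodalLateInteractionEmbeddings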
python/python/lancedb/embeddings/colpali.py (508 lines changed; Normal file → Executable file)
@@ -1,255 +1,347 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
+"""Late-interaction embeddings powered by colpali-engine."""
+
+from __future__ import annotations
+
-from functools import lru_cache
-from typing import List, Union, Optional, Any
-import numpy as np
 import io
+from typing import Any, Dict, List, Optional, Sequence
 
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
 from .registry import register
-from .utils import TEXT, IMAGES, is_flash_attn_2_available
+from .utils import IMAGES, TEXT, is_flash_attn_2_available, weak_lru
 
 
-@register("colpali")
-class ColPaliEmbeddings(EmbeddingFunction):
-    """
-    An embedding function that uses the ColPali engine for
-    multimodal multi-vector embeddings.
+_FAMILY_ALIASES = {
+    "colsmol": {"colsmol", "colsmolvlm", "smol"},
+    "colqwen2.5": {"colqwen2.5", "colqwen25", "colqwen-2.5"},
+    "colqwen2": {"colqwen2", "colqwen-2"},
+    "colpali": {"colpali", "paligemma"},
+}
 
-    This embedding function supports ColQwen2.5 models, producing multivector outputs
-    for both text and image inputs. The output embeddings are lists of vectors, each
-    vector being 128-dimensional by default, represented as List[List[float]].
+_FAMILY_CLASSES = {
+    "colpali": ("ColPali", "ColPaliProcessor"),
+    "colqwen2.5": ("ColQwen2_5", "ColQwen2_5_Processor"),
+    "colqwen2": ("ColQwen2", "ColQwen2Processor"),
+    "colsmol": ("ColIdefics3", "ColIdefics3Processor"),
+}
 
-    Parameters
-    ----------
-    model_name : str
-        The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
-    device : str
-        The device for inference (default "cuda:0").
-    dtype : str
-        Data type for model weights (default "bfloat16").
-    use_token_pooling : bool
-        Whether to use token pooling to reduce embedding size (default True).
-    pool_factor : int
-        Factor to reduce sequence length if token pooling is enabled (default 2).
-    quantization_config : Optional[BitsAndBytesConfig]
-        Quantization configuration for the model. (default None, bitsandbytes needed)
-    batch_size : int
-        Batch size for processing inputs (default 2).
-    """
-
-    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
+def _torch() -> Any:
+    return attempt_import_or_raise("torch", "torch")
+
+
+def _torch_dtype(dtype: str) -> Any:
+    torch = _torch()
+    mapping = {
+        "bfloat16": torch.bfloat16,
+        "float16": torch.float16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+    }
+    if dtype not in mapping:
+        raise ValueError(
+            "Unsupported dtype '{}'. Expected one of {}".format(
+                dtype, ", ".join(sorted(mapping))
+            )
+        )
+    return mapping[dtype]
+
+
+def _load_pooler(use_pooler: bool) -> Optional[Any]:
+    if not use_pooler:
+        return None
+    token_pooling = attempt_import_or_raise(
+        "colpali_engine.compression.token_pooling", "colpali-engine"
+    )
+    pooler_cls = getattr(token_pooling, "HierarchicalTokenPooler", None)
+    if pooler_cls is None:
+        raise ImportError(
+            "colpali_engine HierarchicalTokenPooler not available; update colpali-engine"
+        )
+    return pooler_cls()
+
+
+def _move_to_device(batch: Any, device: Any) -> Any:
+    if device is None:
+        return batch
+    torch = _torch()
+    if isinstance(device, str):
+        device_obj = torch.device(device)
+    else:
+        device_obj = device
+    if isinstance(batch, dict):
+        return {k: _move_to_device(v, device_obj) for k, v in batch.items()}
+    if hasattr(batch, "to"):
+        return batch.to(device_obj)
+    return batch
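A usage sketch for the `_move_to_device` helper above (illustrative; assumes torch is installed). Processor outputs are dict-like batches, so the helper recurses into dict values and calls `.to()` on anything tensor-like, passing everything else through:

import torch
from lancedb.embeddings.colpali import _move_to_device

batch = {"input_ids": torch.zeros((1, 4), dtype=torch.long), "meta": "unchanged"}
moved = _move_to_device(batch, "cpu")
assert moved["input_ids"].device.type == "cpu"  # tensors are moved
assert moved["meta"] == "unchanged"             # non-tensors pass through untouched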
+@register("multimodal-late-interaction")
+class MultimodalLateInteractionEmbeddings(EmbeddingFunction):
+    """Late-interaction embeddings for ViDoRe models."""
+
+    model_name: str = "vidore/colSmol-256M"
+    model_family: Optional[str] = None
+    device: str = "auto"
     dtype: str = "bfloat16"
     use_token_pooling: bool = True
     pool_factor: int = 2
+    batch_size: int = 4
     quantization_config: Optional[Any] = None
-    batch_size: int = 2
-
-    _model = None
-    _processor = None
-    _token_pooler = None
-    _vector_dim = None
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
-        (
-            self._model,
-            self._processor,
-            self._token_pooler,
-        ) = self._load_model(
-            self.model_name,
-            self.dtype,
-            self.device,
-            self.use_token_pooling,
-            self.quantization_config,
-        )
+        self._family = self._resolve_family(self.model_name, self.model_family)
+        self._vector_dim: Optional[int] = None
 
-    @staticmethod
-    @lru_cache(maxsize=1)
-    def _load_model(
-        model_name: str,
-        dtype: str,
-        device: str,
-        use_token_pooling: bool,
-        quantization_config: Optional[Any],
-    ):
-        """
-        Initialize and cache the ColPali model, processor, and token pooler.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
+    @property
+    def model(self) -> Any:
+        """The cached model."""
+        return self._get_models()[0]
+
+    @property
+    def processor(self) -> Any:
+        """The cached processor."""
+        return self._get_models()[1]
+
+    @property
+    def pooler(self) -> Optional[Any]:
+        """The cached pooler."""
+        return self._get_models()[2]
+
+    @property
+    def target_device(self) -> Optional[Any]:
+        """The cached target device."""
+        return self._get_models()[3]
+
+    # ------------------------------------------------------------------
+    # Family detection
+    # ------------------------------------------------------------------
+    @classmethod
+    def _resolve_family(cls, model_name: str, explicit: Optional[str]) -> str:
+        if explicit:
+            family = explicit.lower()
+            if family not in _FAMILY_CLASSES:
+                raise ValueError(
+                    "Unknown model_family '{}'. Expected one of {}".format(
+                        explicit, ", ".join(sorted(_FAMILY_CLASSES))
+                    )
+                )
+            return family
+
+        lowered = model_name.lower()
+        for family, aliases in _FAMILY_ALIASES.items():
+            if any(alias in lowered for alias in aliases):
+                return family
+        return "colpali"
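Concretely, `_resolve_family` matches lowercase aliases against the checkpoint name, with an explicit `model_family` overriding detection. These examples mirror the unit test added later in this commit:

from lancedb.embeddings.colpali import MultimodalLateInteractionEmbeddings

resolve = MultimodalLateInteractionEmbeddings._resolve_family
assert resolve("vidore/colSmol-256M", None) == "colsmol"
assert resolve("vidore/colqwen2.5-v0.2", None) == "colqwen2.5"  # checked before "colqwen2"
assert resolve("custom/model", None) == "colpali"               # fallback family
assert resolve("any/model", "colqwen2") == "colqwen2"           # explicit override wins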
+    # ------------------------------------------------------------------
+    # Model loading
+    # ------------------------------------------------------------------
+    @weak_lru(maxsize=1)
+    def _get_models(self) -> tuple[Any, Optional[Any], Optional[Any], Optional[Any]]:
+        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali-engine")
         transformers = attempt_import_or_raise("transformers", "transformers")
-        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
-        from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
 
-        if quantization_config is not None:
-            if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
-                raise ValueError("quantization_config must be a BitsAndBytesConfig")
+        if (
+            self.quantization_config is not None
+            and not isinstance(
+                self.quantization_config, transformers.BitsAndBytesConfig
+            )
+        ):
+            raise ValueError(
+                "quantization_config must be a transformers.BitsAndBytesConfig instance"
+            )
 
-        if dtype == "bfloat16":
-            torch_dtype = torch.bfloat16
-        elif dtype == "float16":
-            torch_dtype = torch.float16
-        elif dtype == "float64":
-            torch_dtype = torch.float64
+        model_cls_name, processor_cls_name = _FAMILY_CLASSES[self._family]
+        model_cls = getattr(colpali_engine.models, model_cls_name)
+        processor_cls = getattr(colpali_engine.models, processor_cls_name)
+
+        torch = _torch()
+        device_map = self.device
+        target_device: Optional[Any] = None
+        if device_map == "auto":
+            if torch.cuda.is_available():
+                device_map = "cuda:0"
+                target_device = torch.device("cuda:0")
+            elif (
+                getattr(torch.backends, "mps", None)
+                and torch.backends.mps.is_available()
+            ):
+                device_map = "mps"
+                target_device = torch.device("mps")
+            else:
+                device_map = "cpu"
+                target_device = torch.device("cpu")
+        else:
+            try:
+                target_device = torch.device(device_map)
+            except (TypeError, ValueError):  # pragma: no cover - device map dicts
+                target_device = None
+
+        torch_dtype = _torch_dtype(self.dtype)
+        if isinstance(device_map, str) and device_map == "cpu" and torch_dtype in {
+            torch.bfloat16,
+            torch.float16,
+        }:
+            torch_dtype = torch.float32
 
-        model = colpali_engine.models.ColQwen2_5.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            device_map=device,
-            quantization_config=quantization_config
-            if quantization_config is not None
-            else None,
-            attn_implementation="flash_attention_2"
-            if is_flash_attn_2_available()
-            else None,
-        ).eval()
-        processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
-            model_name
-        )
+        load_kwargs: Dict[str, Any] = {
+            "torch_dtype": torch_dtype,
+            "device_map": device_map,
+        }
+        if self.quantization_config is not None:
+            load_kwargs["quantization_config"] = self.quantization_config
+        attn_impl = "flash_attention_2" if is_flash_attn_2_available() else None
+        if attn_impl is not None:
+            load_kwargs["attn_implementation"] = attn_impl
+
+        model = model_cls.from_pretrained(self.model_name, **load_kwargs)
+        if hasattr(model, "eval"):
+            model = model.eval()
+
+        processor = processor_cls.from_pretrained(self.model_name)
+        pooler = _load_pooler(self.use_token_pooling)
+        if target_device is None and hasattr(model, "device"):
+            target_device = getattr(model, "device")
+
+        return model, processor, pooler, target_device
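Note the caching change: the old `@staticmethod` + `@lru_cache` kept model, processor, and pooler alive in a cache keyed by constructor arguments, while `@weak_lru(maxsize=1)` caches per instance without pinning the instance itself. `weak_lru` is imported from `lancedb.embeddings.utils`; a minimal sketch of the pattern (not the vendored code) looks like:

import functools
import weakref

def weak_lru(maxsize: int = 128):
    """lru_cache for methods that holds self only as a weak reference."""
    def decorator(func):
        @functools.lru_cache(maxsize=maxsize)
        def _cached(self_ref, *args, **kwargs):
            # Dereference the weakref before calling the real method.
            return func(self_ref(), *args, **kwargs)

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            return _cached(weakref.ref(self), *args, **kwargs)

        return wrapper
    return decorator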
+    # ------------------------------------------------------------------
+    # Encoding helpers
+    # ------------------------------------------------------------------
+    def _pool_tensor(self, tensor: Any) -> Any:
+        if self.pooler is None:
+            return tensor
+        torch = _torch()
+        assert isinstance(tensor, torch.Tensor)
+        expanded = False
+        if tensor.ndim == 2:
+            tensor = tensor.unsqueeze(0)
+            expanded = True
+        kwargs = {"pool_factor": self.pool_factor, "padding": True}
+        tokenizer = getattr(
+            getattr(self.processor, "tokenizer", None), "padding_side", None
+        )
-        token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
-        return model, processor, token_pooler
-
-    def ndims(self):
-        """
-        Return the dimension of a vector in the multivector output (e.g., 128).
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        if self._vector_dim is None:
-            dummy_query = "test"
-            batch_queries = self._processor.process_queries([dummy_query]).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                query_embeddings = self._model(**batch_queries)
-
-            if self.use_token_pooling and self._token_pooler is not None:
-                query_embeddings = self._token_pooler.pool_embeddings(
-                    query_embeddings,
-                    pool_factor=self.pool_factor,
-                    padding=True,
-                    padding_side=self._processor.tokenizer.padding_side,
-                )
-
-            self._vector_dim = query_embeddings[0].shape[-1]
-        return self._vector_dim
-
-    def _process_embeddings(self, embeddings):
-        """
-        Format model embeddings into List[List[float]].
-        Use token pooling if enabled.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        if self.use_token_pooling and self._token_pooler is not None:
-            embeddings = self._token_pooler.pool_embeddings(
-                embeddings,
-                pool_factor=self.pool_factor,
-                padding=True,
-                padding_side=self._processor.tokenizer.padding_side,
-            )
+        if tokenizer is not None:
+            kwargs["padding_side"] = tokenizer
+        pooled = self.pooler.pool_embeddings(tensor, **kwargs)
+        if expanded:
+            pooled = pooled.squeeze(0)
+        return pooled
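`_pool_tensor` wraps colpali-engine's HierarchicalTokenPooler, which merges neighboring token vectors to shrink the multivector. Roughly, with the default `pool_factor=2`, a page encoded as ~1024 patch vectors of dimension 128 comes back as ~512 vectors of the same dimension, halving storage for a small cost in retrieval quality (shapes below are illustrative only):

# Illustrative shapes; the real pooler merges tokens hierarchically.
n_tokens, dim, pool_factor = 1024, 128, 2
pooled_tokens = n_tokens // pool_factor  # ~512 vectors, each still 128-dim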
+    def _normalize_output(self, embeddings: Any) -> List[List[List[float]]]:
+        torch = _torch()
+        if hasattr(embeddings, "last_hidden_state"):
+            return self._normalize_output(embeddings.last_hidden_state)
+        if isinstance(embeddings, dict) and "last_hidden_state" in embeddings:
+            return self._normalize_output(embeddings["last_hidden_state"])
         if isinstance(embeddings, torch.Tensor):
-            tensors = embeddings.detach().cpu()
-            if tensors.dtype == torch.bfloat16:
-                tensors = tensors.to(torch.float32)
-            return (
-                tensors.numpy()
-                .astype(np.float64 if self.dtype == "float64" else np.float32)
-                .tolist()
-            )
-        return []
+            pooled = self._pool_tensor(embeddings).detach().cpu()
+            if pooled.ndim == 2:
+                pooled = pooled.unsqueeze(0)
+            target = torch.float64 if self.dtype == "float64" else torch.float32
+            return pooled.to(target).numpy().tolist()
+        if isinstance(embeddings, (list, tuple)):
+            results: List[List[List[float]]] = []
+            for item in embeddings:
+                results.extend(self._normalize_output(item))
+            return results
+        raise TypeError(f"Unsupported embedding type {type(embeddings)}")
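The normalized output contract is List[inputs][tokens][dim]: one entry per input, each a 2-D list of token-level vectors. Given a constructed instance (say `func = MultimodalLateInteractionEmbeddings()` with weights downloaded), the nesting looks like:

embs = func.generate_text_embeddings(["a short query"])
assert len(embs) == 1                    # one entry per input string
n_vectors = len(embs[0])                 # one vector per (possibly pooled) token
assert len(embs[0][0]) == func.ndims()   # per-vector dimension, e.g. 128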
-    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
-        """
-        Generate embeddings for text input.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        text = self.sanitize_input(text)
-        all_embeddings = []
+    # ------------------------------------------------------------------
+    # Text encoding
+    # ------------------------------------------------------------------
+    def _encode_text(self, batch: Sequence[str]) -> List[List[List[float]]]:
+        if not self.processor or not hasattr(self.processor, "process_queries"):
+            raise RuntimeError("Processor for text queries is not available for this model")
+        payload = self.processor.process_queries(batch)
+        payload = _move_to_device(payload, self.target_device)
+        torch = _torch()
+        with torch.no_grad():
+            outputs = self.model(**payload)
+        return self._normalize_output(outputs)
 
-        for i in range(0, len(text), self.batch_size):
-            batch_text = text[i : i + self.batch_size]
-            batch_queries = self._processor.process_queries(batch_text).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                query_embeddings = self._model(**batch_queries)
-            all_embeddings.extend(self._process_embeddings(query_embeddings))
-        return all_embeddings
-
-    def _prepare_images(self, images: IMAGES) -> List:
-        """
-        Convert image inputs to PIL Images.
-        """
+    # ------------------------------------------------------------------
+    # Image encoding
+    # ------------------------------------------------------------------
+    def _prepare_images(self, images: IMAGES) -> List[Any]:
         PIL = attempt_import_or_raise("PIL", "pillow")
         requests = attempt_import_or_raise("requests", "requests")
-        images = self.sanitize_input(images)
-        pil_images = []
-        try:
-            for image in images:
-                if isinstance(image, str):
-                    if image.startswith(("http://", "https://")):
-                        response = requests.get(image, timeout=10)
-                        response.raise_for_status()
-                        pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
-                    else:
-                        with PIL.Image.open(image) as im:
-                            pil_images.append(im.copy())
-                elif isinstance(image, bytes):
-                    pil_images.append(PIL.Image.open(io.BytesIO(image)))
-                else:
-                    # Assume it's a PIL Image; will raise if invalid
-                    pil_images.append(image)
-        except Exception as e:
-            raise ValueError(f"Failed to process image: {e}")
+        prepared: List[Any] = []
+        for image in self.sanitize_input(images):
+            if isinstance(image, str) and image.startswith(("http://", "https://")):
+                response = requests.get(image, timeout=10)
+                response.raise_for_status()
+                prepared.append(PIL.Image.open(io.BytesIO(response.content)))
+            elif isinstance(image, str):
+                with PIL.Image.open(image) as img:
+                    prepared.append(img.copy())
+            elif isinstance(image, bytes):
+                prepared.append(PIL.Image.open(io.BytesIO(image)))
+            else:
+                prepared.append(image)
+        return prepared
 
-        return pil_images
+    def _encode_images(self, images: Sequence[Any]) -> List[List[List[float]]]:
+        if not self.processor or not hasattr(self.processor, "process_images"):
+            raise RuntimeError("Processor for images is not available for this model")
+        payload = self.processor.process_images(images)
+        payload = _move_to_device(payload, self.target_device)
+        torch = _torch()
+        with torch.no_grad():
+            outputs = self.model(**payload)
+        return self._normalize_output(outputs)
 
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+    def _batched(self, values: Sequence[Any], encoder) -> List[List[List[float]]]:
+        results: List[List[List[float]]] = []
+        for start in range(0, len(values), self.batch_size):
+            chunk = values[start : start + self.batch_size]
+            results.extend(encoder(chunk))
+        return results
+
+    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
+        text = self.sanitize_input(text)
+        if len(text) == 0:
+            return []
+        return self._batched(text, self._encode_text)
+
     def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
-        """
-        Generate embeddings for a batch of images.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        pil_images = self._prepare_images(images)
-        all_embeddings = []
-
-        for i in range(0, len(pil_images), self.batch_size):
-            batch_images = pil_images[i : i + self.batch_size]
-            batch_images = self._processor.process_images(batch_images).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                image_embeddings = self._model(**batch_images)
-            all_embeddings.extend(self._process_embeddings(image_embeddings))
-        return all_embeddings
+        prepared = self._prepare_images(images)
+        if len(prepared) == 0:
+            return []
+        return self._batched(prepared, self._encode_images)
 
     def compute_query_embeddings(
-        self, query: Union[str, IMAGES], *args, **kwargs
+        self, query: str, *args: Any, **kwargs: Any
     ) -> List[List[List[float]]]:
-        """
-        Compute embeddings for a single user query (text only).
-        """
         if not isinstance(query, str):
-            raise ValueError(
-                "Query must be a string, image to image search is not supported"
-            )
+            raise ValueError("Late interaction queries must be text")
         return self.generate_text_embeddings([query])
 
     def compute_source_embeddings(
-        self, images: IMAGES, *args, **kwargs
+        self, images: IMAGES, *args: Any, **kwargs: Any
     ) -> List[List[List[float]]]:
         """
         Compute embeddings for a batch of source images.
 
         Parameters
         ----------
         images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
             Batch of images (paths, URLs, bytes, or PIL Images).
         """
         images = self.sanitize_input(images)
         return self.generate_image_embeddings(images)
+
+    def ndims(self) -> int:
+        if self._vector_dim is None:
+            probe = self.generate_text_embeddings(["probe"])
+            if not probe or not probe[0]:
+                raise RuntimeError("Failed to determine embedding dimension")
+            self._vector_dim = len(probe[0][0])
+        return self._vector_dim
+
+
+# Backwards compatibility: keep the historical "colpali" key
+register("colpali")(MultimodalLateInteractionEmbeddings)
+
+
+# Legacy class name kept for backwards compatibility in imports
+ColPaliEmbeddings = MultimodalLateInteractionEmbeddings
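Taken together, the rewritten module can be exercised through the registry under either key. A sketch (the 4-bit config is an assumption on top of this diff and needs bitsandbytes installed, but transformers.BitsAndBytesConfig is the type the loader validates against):

from lancedb.embeddings import get_registry
from transformers import BitsAndBytesConfig

registry = get_registry()
func = registry.get("multimodal-late-interaction").create(model_name="vidore/colSmol-256M")
legacy = registry.get("colpali").create(model_name="vidore/colSmol-256M")  # same class

quantized = registry.get("multimodal-late-interaction").create(
    model_name="vidore/colqwen2-v1.0",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # hypothetical config
)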
@@ -17,6 +17,7 @@ from lancedb.embeddings import (
     EmbeddingFunctionRegistry,
 )
 from lancedb.embeddings.base import TextEmbeddingFunction
+from lancedb.embeddings.colpali import MultimodalLateInteractionEmbeddings
 from lancedb.embeddings.registry import get_registry, register
 from lancedb.embeddings.utils import retry
 from lancedb.pydantic import LanceModel, Vector
@@ -515,3 +516,16 @@ def test_openai_propagates_api_key(monkeypatch):
     query = "greetings"
     actual = table.search(query).limit(1).to_pydantic(Words)[0]
     assert len(actual.text) > 0
+
+
+def test_multimodal_late_interaction_family_detection():
+    resolver = MultimodalLateInteractionEmbeddings._resolve_family
+
+    assert resolver("vidore/colSmol-256M", None) == "colsmol"
+    assert resolver("vidore/colqwen2.5-v0.2", None) == "colqwen2.5"
+    assert resolver("vidore/colqwen2-v1.0", None) == "colqwen2"
+    assert resolver("vidore/colpali-v1.3", None) == "colpali"
+    assert resolver("custom/model", None) == "colpali"
+    assert resolver("any/model", "colqwen2") == "colqwen2"
+    with pytest.raises(ValueError):
+        resolver("any/model", "unknown")
@@ -4,6 +4,7 @@
 import importlib
+import io
 import os
 
 import lancedb
 import numpy as np
 import pandas as pd
@@ -33,6 +34,24 @@ try:
 except Exception:
     _imagebind = None
 
+try:
+    import torch
+except ImportError:
+    torch = None
+
+HAS_ACCEL = bool(
+    torch
+    and (
+        torch.cuda.is_available()
+        or getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+    )
+)
+RUN_HEAVY_VIDORE = os.getenv("LANCEDB_TEST_FULL_LATE_INTERACTION") in {"1", "true", "yes"}
+HEAVY_SKIP = pytest.mark.skipif(
+    not (RUN_HEAVY_VIDORE and HAS_ACCEL),
+    reason="Set LANCEDB_TEST_FULL_LATE_INTERACTION=1 and run on GPU to exercise large vidore checkpoints",
+)
+
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
@@ -597,21 +616,44 @@ def test_voyageai_multimodal_embedding_text_function():
     importlib.util.find_spec("colpali_engine") is None,
     reason="colpali_engine not installed",
 )
-def test_colpali(tmp_path):
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        # pytest.param("vidore/colSmol-256M", id="colSmol"),
+        pytest.param(
+            "vidore/colqwen2-v1.0",
+            id="colQwen2",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colqwen2.5-v0.2",
+            id="colQwen2.5",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colpali-v1.3-merged",
+            id="colPali",
+            # marks=HEAVY_SKIP,
+        ),
+    ],
+)
+def test_multimodal_late_interaction_models(tmp_path, model_name):
     import requests
-    from lancedb.pydantic import LanceModel
+    from lancedb.pydantic import LanceModel, Vector
 
     db = lancedb.connect(tmp_path)
     registry = get_registry()
-    func = registry.get("colpali").create()
+    func = registry.get("multimodal-late-interaction").create(
+        model_name=model_name,
+        device="auto",
+        batch_size=1,
+    )
 
     class MediaItems(LanceModel):
         text: str
         image_uri: str = func.SourceField()
         image_bytes: bytes = func.SourceField()
-        image_vectors: MultiVector(func.ndims()) = (
-            func.VectorField()
-        )  # Multivector image embeddings
+        image_vectors: MultiVector(func.ndims()) = func.VectorField()
 
     table = db.create_table("media", schema=MediaItems)
 
@@ -619,41 +661,30 @@
         "a cute cat playing with yarn",
         "a puppy in a flower field",
         "a red sports car on the highway",
-        "a vintage bicycle leaning against a wall",
-        "a plate of delicious pasta",
-        "fresh fruit salad in a bowl",
     ]
 
     uris = [
         "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
-        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
         "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
-        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
-        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
     ]
 
-    # Get images as bytes
     image_bytes = [requests.get(uri).content for uri in uris]
 
     table.add(
         pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
     )
 
-    # Test text-to-image search
-    image_results = (
+    result = (
         table.search("fluffy companion", vector_column_name="image_vectors")
         .limit(1)
         .to_pydantic(MediaItems)[0]
     )
-    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
+    assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
 
-    # Verify multivector dimensions
     first_row = table.to_arrow().to_pylist()[0]
-    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
-    assert len(first_row["image_vectors"][0]) == func.ndims(), (
-        "Vector dimension mismatch"
-    )
+    assert len(first_row["image_vectors"]) > 1
+    assert len(first_row["image_vectors"][0]) == func.ndims()
 
 
 @pytest.mark.slow
python/test.py (58 lines added; new file)
@@ -0,0 +1,58 @@
+import requests
+from lancedb.pydantic import LanceModel, Vector
+import importlib
+import io
+import os
+
+import lancedb
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pytest
+from lancedb.embeddings import get_registry
+from lancedb.pydantic import LanceModel, Vector, MultiVector
+
+db = lancedb.connect("~/.db")
+registry = get_registry()
+func = registry.get("multimodal-late-interaction").create(
+    model_name="vidore/colQwen2.5-v0.2",
+    device="auto",
+    batch_size=1,
+)
+
+class MediaItems(LanceModel):
+    text: str
+    image_uri: str = func.SourceField()
+    image_bytes: bytes = func.SourceField()
+    image_vectors: MultiVector(func.ndims()) = func.VectorField()
+
+table = db.create_table("media", schema=MediaItems, mode="overwrite")
+
+texts = [
+    "a cute cat playing with yarn",
+    "a puppy in a flower field",
+    "a red sports car on the highway",
+]
+
+uris = [
+    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+]
+
+image_bytes = [requests.get(uri).content for uri in uris]
+
+table.add(
+    pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
+)
+
+result = (
+    table.search("fluffy companion", vector_column_name="image_vectors")
+    .limit(1)
+    .to_pydantic(MediaItems)[0]
+)
+assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
+
+first_row = table.to_arrow().to_pylist()[0]
+assert len(first_row["image_vectors"]) > 1
+assert len(first_row["image_vectors"][0]) == func.ndims()
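For context on what searching a MultiVector column computes: late-interaction retrieval scores a query against a document with MaxSim, where each query vector takes the similarity of its best-matching document vector and the per-token maxima are summed. LanceDB evaluates this during multivector search; a numpy sketch of the scoring (shapes assumed, vectors typically L2-normalized so dot product approximates cosine similarity):

import numpy as np

def maxsim(query_vecs: np.ndarray, doc_vecs: np.ndarray) -> float:
    # query_vecs: (n_query_tokens, dim); doc_vecs: (n_doc_tokens, dim)
    sims = query_vecs @ doc_vecs.T        # all pairwise similarities
    return float(sims.max(axis=1).sum())  # best document vector per query token

score = maxsim(np.random.rand(20, 128), np.random.rand(700, 128))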