Mirror of https://github.com/lancedb/lancedb.git
Synced 2026-01-04 19:02:58 +00:00

Compare commits (1 commit): codex/debu...ayush/pyla...

| Author | SHA1 | Date |
|---|---|---|
|  | b766cbe0a9 |  |
@@ -80,7 +80,7 @@ embeddings = [
     "pillow",
     "open-clip-torch",
     "cohere",
-    "colpali-engine>=0.3.10",
+    "colpali-engine>=0.3.12",
     "huggingface_hub",
     "InstructorEmbedding",
     "google.generativeai",
@@ -19,5 +19,5 @@ from .imagebind import ImageBindEmbeddings
 from .jinaai import JinaEmbeddings
 from .watsonx import WatsonxEmbeddings
 from .voyageai import VoyageAIEmbeddingFunction
-from .colpali import ColPaliEmbeddings
+from .colpali import MultimodalLateInteractionEmbeddings, ColPaliEmbeddings  # noqa: F401
 from .siglip import SigLipEmbeddings
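The `# noqa: F401` re-export keeps the legacy import path alive. A minimal sketch of what that guarantees (assuming a build with this commit; `colpali.py` aliases the old name at module level, as shown further down):

```python
# Both names import from the package root and point at the same class,
# because colpali.py ends with: ColPaliEmbeddings = MultimodalLateInteractionEmbeddings
from lancedb.embeddings import ColPaliEmbeddings, MultimodalLateInteractionEmbeddings

assert ColPaliEmbeddings is MultimodalLateInteractionEmbeddings
```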
python/python/lancedb/embeddings/colpali.py: 508 lines changed (Normal file → Executable file)
@@ -1,255 +1,347 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
-from functools import lru_cache
-from typing import List, Union, Optional, Any
-import numpy as np
+"""Late-interaction embeddings powered by colpali-engine."""
+
+from __future__ import annotations
+
 import io
+from typing import Any, Dict, List, Optional, Sequence
 
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
 from .registry import register
-from .utils import TEXT, IMAGES, is_flash_attn_2_available
+from .utils import IMAGES, TEXT, is_flash_attn_2_available, weak_lru
 
 
-@register("colpali")
-class ColPaliEmbeddings(EmbeddingFunction):
-    """
-    An embedding function that uses the ColPali engine for
-    multimodal multi-vector embeddings.
-
-    This embedding function supports ColQwen2.5 models, producing multivector outputs
-    for both text and image inputs. The output embeddings are lists of vectors, each
-    vector being 128-dimensional by default, represented as List[List[float]].
-
-    Parameters
-    ----------
-    model_name : str
-        The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
-    device : str
-        The device for inference (default "cuda:0").
-    dtype : str
-        Data type for model weights (default "bfloat16").
-    use_token_pooling : bool
-        Whether to use token pooling to reduce embedding size (default True).
-    pool_factor : int
-        Factor to reduce sequence length if token pooling is enabled (default 2).
-    quantization_config : Optional[BitsAndBytesConfig]
-        Quantization configuration for the model. (default None, bitsandbytes needed)
-    batch_size : int
-        Batch size for processing inputs (default 2).
-    """
-
-    model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
+_FAMILY_ALIASES = {
+    "colsmol": {"colsmol", "colsmolvlm", "smol"},
+    "colqwen2.5": {"colqwen2.5", "colqwen25", "colqwen-2.5"},
+    "colqwen2": {"colqwen2", "colqwen-2"},
+    "colpali": {"colpali", "paligemma"},
+}
+
+_FAMILY_CLASSES = {
+    "colpali": ("ColPali", "ColPaliProcessor"),
+    "colqwen2.5": ("ColQwen2_5", "ColQwen2_5_Processor"),
+    "colqwen2": ("ColQwen2", "ColQwen2Processor"),
+    "colsmol": ("ColIdefics3", "ColIdefics3Processor"),
+}
+
+
+def _torch() -> Any:
+    return attempt_import_or_raise("torch", "torch")
+
+
+def _torch_dtype(dtype: str) -> Any:
+    torch = _torch()
+    mapping = {
+        "bfloat16": torch.bfloat16,
+        "float16": torch.float16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+    }
+    if dtype not in mapping:
+        raise ValueError(
+            "Unsupported dtype '{}'. Expected one of {}".format(
+                dtype, ", ".join(sorted(mapping))
+            )
+        )
+    return mapping[dtype]
+
+
+def _load_pooler(use_pooler: bool) -> Optional[Any]:
+    if not use_pooler:
+        return None
+    token_pooling = attempt_import_or_raise(
+        "colpali_engine.compression.token_pooling", "colpali-engine"
+    )
+    pooler_cls = getattr(token_pooling, "HierarchicalTokenPooler", None)
+    if pooler_cls is None:
+        raise ImportError(
+            "colpali_engine HierarchicalTokenPooler not available; update colpali-engine"
+        )
+    return pooler_cls()
+
+
+def _move_to_device(batch: Any, device: Any) -> Any:
+    if device is None:
+        return batch
+    torch = _torch()
+    if isinstance(device, str):
+        device_obj = torch.device(device)
+    else:
+        device_obj = device
+    if isinstance(batch, dict):
+        return {k: _move_to_device(v, device_obj) for k, v in batch.items()}
+    if hasattr(batch, "to"):
+        return batch.to(device_obj)
+    return batch
+
+
+@register("multimodal-late-interaction")
+class MultimodalLateInteractionEmbeddings(EmbeddingFunction):
+    """Late-interaction embeddings for ViDoRe models."""
+
+    model_name: str = "vidore/colSmol-256M"
+    model_family: Optional[str] = None
     device: str = "auto"
     dtype: str = "bfloat16"
     use_token_pooling: bool = True
     pool_factor: int = 2
+    batch_size: int = 4
     quantization_config: Optional[Any] = None
-    batch_size: int = 2
 
-    _model = None
-    _processor = None
-    _token_pooler = None
-    _vector_dim = None
-
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
-        (
-            self._model,
-            self._processor,
-            self._token_pooler,
-        ) = self._load_model(
-            self.model_name,
-            self.dtype,
-            self.device,
-            self.use_token_pooling,
-            self.quantization_config,
-        )
+        self._family = self._resolve_family(self.model_name, self.model_family)
+        self._vector_dim: Optional[int] = None
 
-    @staticmethod
-    @lru_cache(maxsize=1)
-    def _load_model(
-        model_name: str,
-        dtype: str,
-        device: str,
-        use_token_pooling: bool,
-        quantization_config: Optional[Any],
-    ):
-        """
-        Initialize and cache the ColPali model, processor, and token pooler.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
+    @property
+    def model(self) -> Any:
+        """The cached model."""
+        return self._get_models()[0]
+
+    @property
+    def processor(self) -> Any:
+        """The cached processor."""
+        return self._get_models()[1]
+
+    @property
+    def pooler(self) -> Optional[Any]:
+        """The cached pooler."""
+        return self._get_models()[2]
+
+    @property
+    def target_device(self) -> Optional[Any]:
+        """The cached target device."""
+        return self._get_models()[3]
+
+    # ------------------------------------------------------------------
+    # Family detection
+    # ------------------------------------------------------------------
+    @classmethod
+    def _resolve_family(cls, model_name: str, explicit: Optional[str]) -> str:
+        if explicit:
+            family = explicit.lower()
+            if family not in _FAMILY_CLASSES:
+                raise ValueError(
+                    "Unknown model_family '{}'. Expected one of {}".format(
+                        explicit, ", ".join(sorted(_FAMILY_CLASSES))
+                    )
+                )
+            return family
+
+        lowered = model_name.lower()
+        for family, aliases in _FAMILY_ALIASES.items():
+            if any(alias in lowered for alias in aliases):
+                return family
+        return "colpali"
+
+    # ------------------------------------------------------------------
+    # Model loading
+    # ------------------------------------------------------------------
+    @weak_lru(maxsize=1)
+    def _get_models(self) -> tuple[Any, Optional[Any], Optional[Any], Optional[Any]]:
+        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali-engine")
         transformers = attempt_import_or_raise("transformers", "transformers")
-        colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
-        from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
 
-        if quantization_config is not None:
-            if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
-                raise ValueError("quantization_config must be a BitsAndBytesConfig")
+        if (
+            self.quantization_config is not None
+            and not isinstance(
+                self.quantization_config, transformers.BitsAndBytesConfig
+            )
+        ):
+            raise ValueError(
+                "quantization_config must be a transformers.BitsAndBytesConfig instance"
+            )
 
-        if dtype == "bfloat16":
-            torch_dtype = torch.bfloat16
-        elif dtype == "float16":
-            torch_dtype = torch.float16
-        elif dtype == "float64":
-            torch_dtype = torch.float64
+        model_cls_name, processor_cls_name = _FAMILY_CLASSES[self._family]
+        model_cls = getattr(colpali_engine.models, model_cls_name)
+        processor_cls = getattr(colpali_engine.models, processor_cls_name)
+
+        torch = _torch()
+        device_map = self.device
+        target_device: Optional[Any] = None
+        if device_map == "auto":
+            if torch.cuda.is_available():
+                device_map = "cuda:0"
+                target_device = torch.device("cuda:0")
+            elif (
+                getattr(torch.backends, "mps", None)
+                and torch.backends.mps.is_available()
+            ):
+                device_map = "mps"
+                target_device = torch.device("mps")
+            else:
+                device_map = "cpu"
+                target_device = torch.device("cpu")
         else:
+            try:
+                target_device = torch.device(device_map)
+            except (TypeError, ValueError):  # pragma: no cover - device map dicts
+                target_device = None
+
+        torch_dtype = _torch_dtype(self.dtype)
+        if isinstance(device_map, str) and device_map == "cpu" and torch_dtype in {
+            torch.bfloat16,
+            torch.float16,
+        }:
             torch_dtype = torch.float32
 
-        model = colpali_engine.models.ColQwen2_5.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            device_map=device,
-            quantization_config=quantization_config
-            if quantization_config is not None
-            else None,
-            attn_implementation="flash_attention_2"
-            if is_flash_attn_2_available()
-            else None,
-        ).eval()
-        processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
-            model_name
-        )
-        token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
-        return model, processor, token_pooler
-
-    def ndims(self):
-        """
-        Return the dimension of a vector in the multivector output (e.g., 128).
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        if self._vector_dim is None:
-            dummy_query = "test"
-            batch_queries = self._processor.process_queries([dummy_query]).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                query_embeddings = self._model(**batch_queries)
-
-            if self.use_token_pooling and self._token_pooler is not None:
-                query_embeddings = self._token_pooler.pool_embeddings(
-                    query_embeddings,
-                    pool_factor=self.pool_factor,
-                    padding=True,
-                    padding_side=self._processor.tokenizer.padding_side,
-                )
-
-            self._vector_dim = query_embeddings[0].shape[-1]
-        return self._vector_dim
-
-    def _process_embeddings(self, embeddings):
-        """
-        Format model embeddings into List[List[float]].
-        Use token pooling if enabled.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        if self.use_token_pooling and self._token_pooler is not None:
-            embeddings = self._token_pooler.pool_embeddings(
-                embeddings,
-                pool_factor=self.pool_factor,
-                padding=True,
-                padding_side=self._processor.tokenizer.padding_side,
-            )
+        load_kwargs: Dict[str, Any] = {
+            "torch_dtype": torch_dtype,
+            "device_map": device_map,
+        }
+        if self.quantization_config is not None:
+            load_kwargs["quantization_config"] = self.quantization_config
+        attn_impl = "flash_attention_2" if is_flash_attn_2_available() else None
+        if attn_impl is not None:
+            load_kwargs["attn_implementation"] = attn_impl
+
+        model = model_cls.from_pretrained(self.model_name, **load_kwargs)
+        if hasattr(model, "eval"):
+            model = model.eval()
+
+        processor = processor_cls.from_pretrained(self.model_name)
+        pooler = _load_pooler(self.use_token_pooling)
+        if target_device is None and hasattr(model, "device"):
+            target_device = getattr(model, "device")
+
+        return model, processor, pooler, target_device
+
+    # ------------------------------------------------------------------
+    # Encoding helpers
+    # ------------------------------------------------------------------
+    def _pool_tensor(self, tensor: Any) -> Any:
+        if self.pooler is None:
+            return tensor
+        torch = _torch()
+        assert isinstance(tensor, torch.Tensor)
+        expanded = False
+        if tensor.ndim == 2:
+            tensor = tensor.unsqueeze(0)
+            expanded = True
+        kwargs = {"pool_factor": self.pool_factor, "padding": True}
+        tokenizer = getattr(
+            getattr(self.processor, "tokenizer", None), "padding_side", None
+        )
+        if tokenizer is not None:
+            kwargs["padding_side"] = tokenizer
+        pooled = self.pooler.pool_embeddings(tensor, **kwargs)
+        if expanded:
+            pooled = pooled.squeeze(0)
+        return pooled
+
+    def _normalize_output(self, embeddings: Any) -> List[List[List[float]]]:
+        torch = _torch()
+        if hasattr(embeddings, "last_hidden_state"):
+            return self._normalize_output(embeddings.last_hidden_state)
+        if isinstance(embeddings, dict) and "last_hidden_state" in embeddings:
+            return self._normalize_output(embeddings["last_hidden_state"])
         if isinstance(embeddings, torch.Tensor):
-            tensors = embeddings.detach().cpu()
-            if tensors.dtype == torch.bfloat16:
-                tensors = tensors.to(torch.float32)
-            return (
-                tensors.numpy()
-                .astype(np.float64 if self.dtype == "float64" else np.float32)
-                .tolist()
-            )
-        return []
-
-    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
-        """
-        Generate embeddings for text input.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        text = self.sanitize_input(text)
-        all_embeddings = []
-        for i in range(0, len(text), self.batch_size):
-            batch_text = text[i : i + self.batch_size]
-            batch_queries = self._processor.process_queries(batch_text).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                query_embeddings = self._model(**batch_queries)
-            all_embeddings.extend(self._process_embeddings(query_embeddings))
-        return all_embeddings
-
-    def _prepare_images(self, images: IMAGES) -> List:
-        """
-        Convert image inputs to PIL Images.
-        """
+            pooled = self._pool_tensor(embeddings).detach().cpu()
+            if pooled.ndim == 2:
+                pooled = pooled.unsqueeze(0)
+            target = torch.float64 if self.dtype == "float64" else torch.float32
+            return pooled.to(target).numpy().tolist()
+        if isinstance(embeddings, (list, tuple)):
+            results: List[List[List[float]]] = []
+            for item in embeddings:
+                results.extend(self._normalize_output(item))
+            return results
+        raise TypeError(f"Unsupported embedding type {type(embeddings)}")
+
+    # ------------------------------------------------------------------
+    # Text encoding
+    # ------------------------------------------------------------------
+    def _encode_text(self, batch: Sequence[str]) -> List[List[List[float]]]:
+        if not self.processor or not hasattr(self.processor, "process_queries"):
+            raise RuntimeError("Processor for text queries is not available for this model")
+        payload = self.processor.process_queries(batch)
+        payload = _move_to_device(payload, self.target_device)
+        torch = _torch()
+        with torch.no_grad():
+            outputs = self.model(**payload)
+        return self._normalize_output(outputs)
+
+    # ------------------------------------------------------------------
+    # Image encoding
+    # ------------------------------------------------------------------
+    def _prepare_images(self, images: IMAGES) -> List[Any]:
         PIL = attempt_import_or_raise("PIL", "pillow")
         requests = attempt_import_or_raise("requests", "requests")
-        images = self.sanitize_input(images)
-        pil_images = []
-        try:
-            for image in images:
-                if isinstance(image, str):
-                    if image.startswith(("http://", "https://")):
-                        response = requests.get(image, timeout=10)
-                        response.raise_for_status()
-                        pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
-                    else:
-                        with PIL.Image.open(image) as im:
-                            pil_images.append(im.copy())
-                elif isinstance(image, bytes):
-                    pil_images.append(PIL.Image.open(io.BytesIO(image)))
-                else:
-                    # Assume it's a PIL Image; will raise if invalid
-                    pil_images.append(image)
-        except Exception as e:
-            raise ValueError(f"Failed to process image: {e}")
-
-        return pil_images
+        prepared: List[Any] = []
+        for image in self.sanitize_input(images):
+            if isinstance(image, str) and image.startswith(("http://", "https://")):
+                response = requests.get(image, timeout=10)
+                response.raise_for_status()
+                prepared.append(PIL.Image.open(io.BytesIO(response.content)))
+            elif isinstance(image, str):
+                with PIL.Image.open(image) as img:
+                    prepared.append(img.copy())
+            elif isinstance(image, bytes):
+                prepared.append(PIL.Image.open(io.BytesIO(image)))
+            else:
+                prepared.append(image)
+        return prepared
+
+    def _encode_images(self, images: Sequence[Any]) -> List[List[List[float]]]:
+        if not self.processor or not hasattr(self.processor, "process_images"):
+            raise RuntimeError("Processor for images is not available for this model")
+        payload = self.processor.process_images(images)
+        payload = _move_to_device(payload, self.target_device)
+        torch = _torch()
+        with torch.no_grad():
+            outputs = self.model(**payload)
+        return self._normalize_output(outputs)
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+    def _batched(self, values: Sequence[Any], encoder) -> List[List[List[float]]]:
+        results: List[List[List[float]]] = []
+        for start in range(0, len(values), self.batch_size):
+            chunk = values[start : start + self.batch_size]
+            results.extend(encoder(chunk))
+        return results
+
+    def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
+        text = self.sanitize_input(text)
+        if len(text) == 0:
+            return []
+        return self._batched(text, self._encode_text)
 
     def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
-        """
-        Generate embeddings for a batch of images.
-        """
-        torch = attempt_import_or_raise("torch", "torch")
-        pil_images = self._prepare_images(images)
-        all_embeddings = []
-
-        for i in range(0, len(pil_images), self.batch_size):
-            batch_images = pil_images[i : i + self.batch_size]
-            batch_images = self._processor.process_images(batch_images).to(
-                self._model.device
-            )
-            with torch.no_grad():
-                image_embeddings = self._model(**batch_images)
-            all_embeddings.extend(self._process_embeddings(image_embeddings))
-        return all_embeddings
+        prepared = self._prepare_images(images)
+        if len(prepared) == 0:
+            return []
+        return self._batched(prepared, self._encode_images)
 
     def compute_query_embeddings(
-        self, query: Union[str, IMAGES], *args, **kwargs
+        self, query: str, *args: Any, **kwargs: Any
     ) -> List[List[List[float]]]:
-        """
-        Compute embeddings for a single user query (text only).
-        """
         if not isinstance(query, str):
-            raise ValueError(
-                "Query must be a string, image to image search is not supported"
-            )
+            raise ValueError("Late interaction queries must be text")
         return self.generate_text_embeddings([query])
 
     def compute_source_embeddings(
-        self, images: IMAGES, *args, **kwargs
+        self, images: IMAGES, *args: Any, **kwargs: Any
     ) -> List[List[List[float]]]:
-        """
-        Compute embeddings for a batch of source images.
-
-        Parameters
-        ----------
-        images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
-            Batch of images (paths, URLs, bytes, or PIL Images).
-        """
-        images = self.sanitize_input(images)
         return self.generate_image_embeddings(images)
+
+    def ndims(self) -> int:
+        if self._vector_dim is None:
+            probe = self.generate_text_embeddings(["probe"])
+            if not probe or not probe[0]:
+                raise RuntimeError("Failed to determine embedding dimension")
+            self._vector_dim = len(probe[0][0])
+        return self._vector_dim


+# Backwards compatibility: keep the historical "colpali" key
+register("colpali")(MultimodalLateInteractionEmbeddings)
+
+
+# Legacy class name kept for backwards compatibility in imports
+ColPaliEmbeddings = MultimodalLateInteractionEmbeddings
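Net effect of the rewrite: the hard-coded ColQwen2.5 loader becomes a family-dispatched one, with lazy `weak_lru`-cached loading behind the `model`/`processor`/`pooler`/`target_device` properties. A usage sketch against the new registry key (the `my-org/custom-checkpoint` id is hypothetical; the other names appear in this diff, and all of them trigger a checkpoint download on first use):

```python
from lancedb.embeddings import get_registry

registry = get_registry()

# Family inferred from the checkpoint name via _FAMILY_ALIASES:
func = registry.get("multimodal-late-interaction").create(
    model_name="vidore/colqwen2-v1.0",  # resolves to the "colqwen2" family
    device="auto",
)

# The historical "colpali" key is re-registered to the same class:
legacy = registry.get("colpali").create(model_name="vidore/colpali-v1.3")

# Checkpoints whose names match no alias can pin a family explicitly:
pinned = registry.get("multimodal-late-interaction").create(
    model_name="my-org/custom-checkpoint",  # hypothetical model id
    model_family="colsmol",
)
```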
@@ -17,6 +17,7 @@ from lancedb.embeddings import (
     EmbeddingFunctionRegistry,
 )
 from lancedb.embeddings.base import TextEmbeddingFunction
+from lancedb.embeddings.colpali import MultimodalLateInteractionEmbeddings
 from lancedb.embeddings.registry import get_registry, register
 from lancedb.embeddings.utils import retry
 from lancedb.pydantic import LanceModel, Vector
@@ -515,3 +516,16 @@ def test_openai_propagates_api_key(monkeypatch):
     query = "greetings"
     actual = table.search(query).limit(1).to_pydantic(Words)[0]
     assert len(actual.text) > 0
+
+
+def test_multimodal_late_interaction_family_detection():
+    resolver = MultimodalLateInteractionEmbeddings._resolve_family
+
+    assert resolver("vidore/colSmol-256M", None) == "colsmol"
+    assert resolver("vidore/colqwen2.5-v0.2", None) == "colqwen2.5"
+    assert resolver("vidore/colqwen2-v1.0", None) == "colqwen2"
+    assert resolver("vidore/colpali-v1.3", None) == "colpali"
+    assert resolver("custom/model", None) == "colpali"
+    assert resolver("any/model", "colqwen2") == "colqwen2"
+    with pytest.raises(ValueError):
+        resolver("any/model", "unknown")
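One subtlety these assertions rely on: dict order in `_FAMILY_ALIASES` is load-bearing. `"colqwen2.5"` is declared before `"colqwen2"`, so a name containing both substrings resolves to the more specific family. A standalone sketch of the same matching logic:

```python
# Insertion order matters: the more specific "colqwen2.5" is checked first.
aliases = {
    "colqwen2.5": {"colqwen2.5", "colqwen25", "colqwen-2.5"},
    "colqwen2": {"colqwen2", "colqwen-2"},
}

def resolve(name: str) -> str:
    lowered = name.lower()
    for family, names in aliases.items():
        if any(alias in lowered for alias in names):
            return family
    return "colpali"  # fallback family, as in the diff

assert resolve("vidore/colqwen2.5-v0.2") == "colqwen2.5"  # not "colqwen2"
assert resolve("vidore/colqwen2-v1.0") == "colqwen2"
```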
@@ -4,6 +4,7 @@
 import importlib
 import io
 import os
+
 import lancedb
 import numpy as np
 import pandas as pd
@@ -33,6 +34,24 @@ try:
 except Exception:
     _imagebind = None
 
+try:
+    import torch
+except ImportError:
+    torch = None
+
+HAS_ACCEL = bool(
+    torch
+    and (
+        torch.cuda.is_available()
+        or getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+    )
+)
+RUN_HEAVY_VIDORE = os.getenv("LANCEDB_TEST_FULL_LATE_INTERACTION") in {"1", "true", "yes"}
+HEAVY_SKIP = pytest.mark.skipif(
+    not (RUN_HEAVY_VIDORE and HAS_ACCEL),
+    reason="Set LANCEDB_TEST_FULL_LATE_INTERACTION=1 and run on GPU to exercise large vidore checkpoints",
+)
+
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
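Note that `HEAVY_SKIP` is defined here but its use sites in the parametrize block below are commented out, so the large checkpoints always run. Re-enabling the gate is one line per param; a self-contained sketch of the pattern (the test body is a placeholder):

```python
import os

import pytest

RUN_HEAVY = os.getenv("LANCEDB_TEST_FULL_LATE_INTERACTION") in {"1", "true", "yes"}
HEAVY_SKIP = pytest.mark.skipif(
    not RUN_HEAVY, reason="set LANCEDB_TEST_FULL_LATE_INTERACTION=1 to run"
)

@pytest.mark.parametrize(
    "model_name",
    [pytest.param("vidore/colqwen2-v1.0", id="colQwen2", marks=HEAVY_SKIP)],
)
def test_heavy_checkpoint(model_name):
    assert model_name  # placeholder; the real test lives in the diff below
```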
@@ -597,21 +616,44 @@ def test_voyageai_multimodal_embedding_text_function():
     importlib.util.find_spec("colpali_engine") is None,
     reason="colpali_engine not installed",
 )
-def test_colpali(tmp_path):
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        # pytest.param("vidore/colSmol-256M", id="colSmol"),
+        pytest.param(
+            "vidore/colqwen2-v1.0",
+            id="colQwen2",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colqwen2.5-v0.2",
+            id="colQwen2.5",
+            # marks=HEAVY_SKIP,
+        ),
+        pytest.param(
+            "vidore/colpali-v1.3-merged",
+            id="colPali",
+            # marks=HEAVY_SKIP,
+        ),
+    ],
+)
+def test_multimodal_late_interaction_models(tmp_path, model_name):
     import requests
-    from lancedb.pydantic import LanceModel
+    from lancedb.pydantic import LanceModel, Vector
 
     db = lancedb.connect(tmp_path)
     registry = get_registry()
-    func = registry.get("colpali").create()
+    func = registry.get("multimodal-late-interaction").create(
+        model_name=model_name,
+        device="auto",
+        batch_size=1,
+    )
 
     class MediaItems(LanceModel):
         text: str
         image_uri: str = func.SourceField()
         image_bytes: bytes = func.SourceField()
-        image_vectors: MultiVector(func.ndims()) = (
-            func.VectorField()
-        )  # Multivector image embeddings
+        image_vectors: MultiVector(func.ndims()) = func.VectorField()
 
     table = db.create_table("media", schema=MediaItems)
 
@@ -619,41 +661,30 @@ def test_colpali(tmp_path):
         "a cute cat playing with yarn",
         "a puppy in a flower field",
         "a red sports car on the highway",
-        "a vintage bicycle leaning against a wall",
-        "a plate of delicious pasta",
-        "fresh fruit salad in a bowl",
     ]
 
     uris = [
         "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
         "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
-        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
         "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
-        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
-        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
     ]
 
-    # Get images as bytes
     image_bytes = [requests.get(uri).content for uri in uris]
 
     table.add(
         pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
     )
 
-    # Test text-to-image search
-    image_results = (
+    result = (
         table.search("fluffy companion", vector_column_name="image_vectors")
         .limit(1)
         .to_pydantic(MediaItems)[0]
     )
-    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
+    assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
 
-    # Verify multivector dimensions
     first_row = table.to_arrow().to_pylist()[0]
-    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
-    assert len(first_row["image_vectors"][0]) == func.ndims(), (
-        "Vector dimension mismatch"
-    )
+    assert len(first_row["image_vectors"]) > 1
+    assert len(first_row["image_vectors"][0]) == func.ndims()
 
 
 @pytest.mark.slow
python/test.py: 58 lines (new file)

@@ -0,0 +1,58 @@
+import requests
+from lancedb.pydantic import LanceModel, Vector
+import importlib
+import io
+import os
+
+import lancedb
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pytest
+from lancedb.embeddings import get_registry
+from lancedb.pydantic import LanceModel, Vector, MultiVector
+
+db = lancedb.connect("~/.db")
+registry = get_registry()
+func = registry.get("multimodal-late-interaction").create(
+    model_name="vidore/colQwen2.5-v0.2",
+    device="auto",
+    batch_size=1,
+)
+
+class MediaItems(LanceModel):
+    text: str
+    image_uri: str = func.SourceField()
+    image_bytes: bytes = func.SourceField()
+    image_vectors: MultiVector(func.ndims()) = func.VectorField()
+
+table = db.create_table("media", schema=MediaItems, mode="overwrite")
+
+texts = [
+    "a cute cat playing with yarn",
+    "a puppy in a flower field",
+    "a red sports car on the highway",
+]
+
+uris = [
+    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+]
+
+image_bytes = [requests.get(uri).content for uri in uris]
+
+table.add(
+    pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
+)
+
+result = (
+    table.search("fluffy companion", vector_column_name="image_vectors")
+    .limit(1)
+    .to_pydantic(MediaItems)[0]
+)
+assert any(keyword in result.text.lower() for keyword in ("cat", "puppy"))
+
+first_row = table.to_arrow().to_pylist()[0]
+assert len(first_row["image_vectors"]) > 1
+assert len(first_row["image_vectors"][0]) == func.ndims()
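The script above exercises the full ingest-and-search path. For a quicker sanity check of the output shape alone, something like this sketch works (a checkpoint download is still required; the number of vectors per input depends on the model and `pool_factor`):

```python
from lancedb.embeddings import get_registry

func = get_registry().get("multimodal-late-interaction").create(
    model_name="vidore/colSmol-256M", device="auto", batch_size=1
)

# Each input maps to a list of token-level vectors (List[List[float]]),
# so a batch of queries comes back as List[List[List[float]]].
vectors = func.generate_text_embeddings(["a cute cat"])
assert len(vectors) == 1                   # one query in, one multivector out
assert len(vectors[0]) >= 1                # one vector per (pooled) token
assert len(vectors[0][0]) == func.ndims()  # per-vector dimension
```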