feat: expand support for multivector colpali models and enhancements (#2719)

Ayush Chaurasia
2025-10-17 14:36:32 +05:30
committed by GitHub
parent bf55feb9b6
commit 3f2e3986e9
2 changed files with 215 additions and 25 deletions


@@ -3,9 +3,11 @@
from functools import lru_cache
from typing import List, Union, Optional, Any
from logging import warning
from typing import List, Union, Optional, Any, Callable
import numpy as np
import io
import warnings
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
@@ -19,35 +21,52 @@ class ColPaliEmbeddings(EmbeddingFunction):
An embedding function that uses the ColPali engine for
multimodal multi-vector embeddings.
This embedding function supports ColQwen2.5 models, producing multivector outputs
for both text and image inputs. The output embeddings are lists of vectors, each
vector being 128-dimensional by default, represented as List[List[float]].
This embedding function supports ColPali models, producing multivector outputs
for both text and image inputs.
Parameters
----------
model_name : str
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
Supports models based on these engines:
- ColPali: "vidore/colpali-v1.3" and others
- ColQwen2.5: "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" and others
- ColQwen2: "vidore/colqwen2-v1.0" and others
- ColSmol: "vidore/colSmol-256M" and others
device : str
The device for inference (default "cuda:0").
The device for inference (default "auto").
dtype : str
Data type for model weights (default "bfloat16").
use_token_pooling : bool
Whether to use token pooling to reduce embedding size (default True).
DEPRECATED. Whether to use token pooling. Use `pooling_strategy` instead.
pooling_strategy : str, optional
The token pooling strategy to use, by default "hierarchical".
- "hierarchical": Progressively pools tokens to reduce sequence length.
- "lambda": A simpler pooling that uses a custom `pooling_func`.
pooling_func: typing.Callable, optional
A function to use for pooling when `pooling_strategy` is "lambda".
pool_factor : int
Factor to reduce sequence length if token pooling is enabled (default 2).
quantization_config : Optional[BitsAndBytesConfig]
Quantization configuration for the model (default None; requires bitsandbytes).
batch_size : int
Batch size for processing inputs (default 2).
offload_folder: str, optional
Folder to offload model weights if using CPU offloading (default None). This is
useful for large models that do not fit in memory.
"""
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
device: str = "auto"
dtype: str = "bfloat16"
use_token_pooling: bool = True
pooling_strategy: Optional[str] = "hierarchical"
pooling_func: Optional[Any] = None
pool_factor: int = 2
quantization_config: Optional[Any] = None
batch_size: int = 2
offload_folder: Optional[str] = None
_model = None
_processor = None
@@ -56,15 +75,43 @@ class ColPaliEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
torch = attempt_import_or_raise("torch", "torch")
if not self.use_token_pooling:
warnings.warn(
"use_token_pooling is deprecated, use pooling_strategy=None instead",
DeprecationWarning,
)
self.pooling_strategy = None
if self.pooling_strategy == "lambda" and self.pooling_func is None:
raise ValueError(
"pooling_func must be provided when pooling_strategy is 'lambda'"
)
device = self.device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
dtype = self.dtype
if device == "mps" and dtype == "bfloat16":
dtype = "float32" # Avoid NaNs on MPS
(
self._model,
self._processor,
self._token_pooler,
) = self._load_model(
self.model_name,
self.dtype,
self.device,
self.use_token_pooling,
dtype,
device,
self.pooling_strategy,
self.pooling_func,
self.quantization_config,
)
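For the quantized path, the caller is expected to pass a `transformers.BitsAndBytesConfig` (the isinstance check in the loader below enforces this). A minimal sketch with an illustrative 4-bit configuration; `bitsandbytes` must be installed, and the model name is just an example:

from transformers import BitsAndBytesConfig
from lancedb.embeddings import get_registry

# Illustrative 4-bit quantization config; requires the bitsandbytes package
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

func_4bit = get_registry().get("colpali").create(
    model_name="Metric-AI/ColQwen2.5-3b-multilingual-v1.0",
    quantization_config=bnb_config,
)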
@@ -74,16 +121,26 @@ class ColPaliEmbeddings(EmbeddingFunction):
model_name: str,
dtype: str,
device: str,
use_token_pooling: bool,
pooling_strategy: Optional[str],
pooling_func: Optional[Callable],
quantization_config: Optional[Any],
):
"""
Initialize and cache the ColPali model, processor, and token pooler.
"""
if device.startswith("mps"):
# Warn: some torch ops in the late-interaction architecture can produce NaNs on MPS
warning(
"MPS device detected. Some operations may result in NaNs. "
"If you encounter issues, consider using 'cpu' or 'cuda' devices."
)
torch = attempt_import_or_raise("torch", "torch")
transformers = attempt_import_or_raise("transformers", "transformers")
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
from colpali_engine.compression.token_pooling import (
HierarchicalTokenPooler,
LambdaTokenPooler,
)
if quantization_config is not None:
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
@@ -98,21 +155,45 @@ class ColPaliEmbeddings(EmbeddingFunction):
else:
torch_dtype = torch.float32
model = colpali_engine.models.ColQwen2_5.from_pretrained(
model_class, processor_class = None, None
model_name_lower = model_name.lower()
if "colqwen2.5" in model_name_lower:
model_class = colpali_engine.models.ColQwen2_5
processor_class = colpali_engine.models.ColQwen2_5_Processor
elif "colsmol" in model_name_lower or "colidefics3" in model_name_lower:
model_class = colpali_engine.models.ColIdefics3
processor_class = colpali_engine.models.ColIdefics3Processor
elif "colqwen" in model_name_lower:
model_class = colpali_engine.models.ColQwen2
processor_class = colpali_engine.models.ColQwen2Processor
elif "colpali" in model_name_lower:
model_class = colpali_engine.models.ColPali
processor_class = colpali_engine.models.ColPaliProcessor
if model_class is None:
raise ValueError(f"Unsupported model: {model_name}")
model = model_class.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map=device,
quantization_config=quantization_config
if quantization_config is not None
else None,
attn_implementation="flash_attention_2"
if is_flash_attn_2_available()
else None,
low_cpu_mem_usage=True,
).eval()
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
model_name
)
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
model = model.to(device)
model = model.to(torch_dtype) # Force cast after moving to device
processor = processor_class.from_pretrained(model_name)
token_pooler = None
if pooling_strategy == "hierarchical":
token_pooler = HierarchicalTokenPooler()
elif pooling_strategy == "lambda":
token_pooler = LambdaTokenPooler(pool_func=pooling_func)
return model, processor, token_pooler
def ndims(self):
@@ -128,7 +209,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
if self.use_token_pooling and self._token_pooler is not None:
if self.pooling_strategy and self._token_pooler is not None:
query_embeddings = self._token_pooler.pool_embeddings(
query_embeddings,
pool_factor=self.pool_factor,
@@ -145,13 +226,20 @@ class ColPaliEmbeddings(EmbeddingFunction):
Use token pooling if enabled.
"""
torch = attempt_import_or_raise("torch", "torch")
if self.use_token_pooling and self._token_pooler is not None:
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
if self.pooling_strategy and self._token_pooler is not None:
if self.pooling_strategy == "hierarchical":
embeddings = self._token_pooler.pool_embeddings(
embeddings,
pool_factor=self.pool_factor,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
elif self.pooling_strategy == "lambda":
embeddings = self._token_pooler.pool_embeddings(
embeddings,
padding=True,
padding_side=self._processor.tokenizer.padding_side,
)
if isinstance(embeddings, torch.Tensor):
tensors = embeddings.detach().cpu()
@@ -179,6 +267,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
)
with torch.no_grad():
query_embeddings = self._model(**batch_queries)
query_embeddings = torch.nan_to_num(query_embeddings)
all_embeddings.extend(self._process_embeddings(query_embeddings))
return all_embeddings
@@ -225,6 +314,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
)
with torch.no_grad():
image_embeddings = self._model(**batch_images)
image_embeddings = torch.nan_to_num(image_embeddings)
all_embeddings.extend(self._process_embeddings(image_embeddings))
return all_embeddings
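Taken together, the expanded model dispatch and the pooling options are exercised through the embedding registry. A minimal usage sketch, mirroring the registry calls in the tests below; the model name and pool factor are illustrative:

from lancedb.embeddings import get_registry

registry = get_registry()

# Any supported checkpoint works: ColPali, ColQwen2, ColQwen2.5, or ColSmol/ColIdefics3
func = registry.get("colpali").create(
    model_name="vidore/colSmol-256M",
    device="auto",                    # resolves to cuda, mps, or cpu
    pooling_strategy="hierarchical",  # or "lambda" with a pooling_func, or None
    pool_factor=2,
)

# Multivector output: one list of func.ndims()-dimensional vectors per input
text_vectors = func.generate_text_embeddings(["a cute cat playing with yarn"])[0]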


@@ -656,6 +656,106 @@ def test_colpali(tmp_path):
)
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
@pytest.mark.parametrize(
"model_name",
[
"vidore/colSmol-256M",
"vidore/colqwen2.5-v0.2",
"vidore/colpali-v1.3",
"vidore/colqwen2-v1.0",
],
)
def test_colpali_models(tmp_path, model_name):
import requests
from lancedb.pydantic import LanceModel
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("colpali").create(model_name=model_name)
class MediaItems(LanceModel):
text: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
image_vectors: MultiVector(func.ndims()) = func.VectorField()
table = db.create_table(f"media_{model_name.replace('/', '_')}", schema=MediaItems)
texts = [
"a cute cat playing with yarn",
]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
]
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
)
image_results = (
table.search("fluffy companion", vector_column_name="image_vectors")
.limit(1)
.to_pydantic(MediaItems)[0]
)
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
first_row = table.to_arrow().to_pylist()[0]
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
assert len(first_row["image_vectors"][0]) == func.ndims(), (
"Vector dimension mismatch"
)
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("colpali_engine") is None,
reason="colpali_engine not installed",
)
def test_colpali_pooling(tmp_path):
registry = get_registry()
model_name = "vidore/colSmol-256M"
test_sentence = "a test sentence for pooling"
# 1. Get embeddings with no pooling
func_no_pool = registry.get("colpali").create(
model_name=model_name, pooling_strategy=None
)
unpooled_embeddings = func_no_pool.generate_text_embeddings([test_sentence])[0]
original_length = len(unpooled_embeddings)
assert original_length > 1
# 2. Test hierarchical pooling
func_hierarchical = registry.get("colpali").create(
model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
)
hierarchical_embeddings = func_hierarchical.generate_text_embeddings(
[test_sentence]
)[0]
expected_hierarchical_length = (original_length + 1) // 2
assert len(hierarchical_embeddings) == expected_hierarchical_length
# 3. Test lambda pooling
def simple_pool_func(tensor):
return tensor[::2]
func_lambda = registry.get("colpali").create(
model_name=model_name,
pooling_strategy="lambda",
pooling_func=simple_pool_func,
)
lambda_embeddings = func_lambda.generate_text_embeddings([test_sentence])[0]
expected_lambda_length = (original_length + 1) // 2
assert len(lambda_embeddings) == expected_lambda_length
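If averaging neighbouring tokens is preferable to dropping every other one, a lambda-pooling callable could instead mean-pool token pairs. A sketch, assuming (as `simple_pool_func` above suggests) that the callable receives a 2-D tokens-by-dimension tensor per sequence; `mean_pool_pairs` is an illustrative name:

import torch
from lancedb.embeddings import get_registry

def mean_pool_pairs(tensor):
    # Pad odd-length sequences by repeating the last token, then average each adjacent pair
    if tensor.shape[0] % 2 == 1:
        tensor = torch.cat([tensor, tensor[-1:]], dim=0)
    return tensor.reshape(-1, 2, tensor.shape[-1]).mean(dim=1)

func_mean = get_registry().get("colpali").create(
    model_name="vidore/colSmol-256M",
    pooling_strategy="lambda",
    pooling_func=mean_pool_pairs,
)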
@pytest.mark.slow
def test_siglip(tmp_path, test_images, query_image_bytes):
from PIL import Image