diff --git a/python/python/lancedb/embeddings/colpali.py b/python/python/lancedb/embeddings/colpali.py
index 150eaa52..52b0d113 100644
--- a/python/python/lancedb/embeddings/colpali.py
+++ b/python/python/lancedb/embeddings/colpali.py
@@ -3,9 +3,11 @@
 
 from functools import lru_cache
-from typing import List, Union, Optional, Any
+from logging import warning
+from typing import List, Union, Optional, Any, Callable
 import numpy as np
 import io
+import warnings
 
 from ..util import attempt_import_or_raise
 from .base import EmbeddingFunction
@@ -19,35 +21,52 @@ class ColPaliEmbeddings(EmbeddingFunction):
     An embedding function that uses the ColPali engine for multimodal
     multi-vector embeddings.
 
-    This embedding function supports ColQwen2.5 models, producing multivector outputs
-    for both text and image inputs. The output embeddings are lists of vectors, each
-    vector being 128-dimensional by default, represented as List[List[float]].
+    This embedding function supports ColPali models, producing multivector outputs
+    for both text and image inputs.
 
     Parameters
     ----------
     model_name : str
         The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
+        Supports models based on these engines:
+        - ColPali: "vidore/colpali-v1.3" and others
+        - ColQwen2.5: "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" and others
+        - ColQwen2: "vidore/colqwen2-v1.0" and others
+        - ColSmol: "vidore/colSmol-256M" and others
+
     device : str
-        The device for inference (default "cuda:0").
+        The device for inference (default "auto").
     dtype : str
         Data type for model weights (default "bfloat16").
     use_token_pooling : bool
-        Whether to use token pooling to reduce embedding size (default True).
+        DEPRECATED. Whether to use token pooling. Use `pooling_strategy` instead.
+    pooling_strategy : str, optional
+        The token pooling strategy to use, by default "hierarchical".
+        - "hierarchical": Progressively pools tokens to reduce sequence length.
+        - "lambda": Pools tokens with a user-provided `pooling_func`.
+    pooling_func : typing.Callable, optional
+        A function to use for pooling when `pooling_strategy` is "lambda".
     pool_factor : int
         Factor to reduce sequence length if token pooling is enabled (default 2).
     quantization_config : Optional[BitsAndBytesConfig]
         Quantization configuration for the model. (default None, bitsandbytes needed)
     batch_size : int
         Batch size for processing inputs (default 2).
+    offload_folder : str, optional
+        Folder to offload model weights if using CPU offloading (default None). This is
+        useful for large models that do not fit in memory.
     """
 
     model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
     device: str = "auto"
     dtype: str = "bfloat16"
     use_token_pooling: bool = True
+    pooling_strategy: Optional[str] = "hierarchical"
+    pooling_func: Optional[Any] = None
     pool_factor: int = 2
     quantization_config: Optional[Any] = None
     batch_size: int = 2
+    offload_folder: Optional[str] = None
 
     _model = None
     _processor = None
@@ -56,15 +75,43 @@ class ColPaliEmbeddings(EmbeddingFunction):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        torch = attempt_import_or_raise("torch", "torch")
+
+        if not self.use_token_pooling:
+            warnings.warn(
+                "use_token_pooling is deprecated, use pooling_strategy=None instead",
+                DeprecationWarning,
+            )
+            self.pooling_strategy = None
+
+        if self.pooling_strategy == "lambda" and self.pooling_func is None:
+            raise ValueError(
+                "pooling_func must be provided when pooling_strategy is 'lambda'"
+            )
+
+        device = self.device
+        if device == "auto":
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
+
+        dtype = self.dtype
+        if device == "mps" and dtype == "bfloat16":
+            dtype = "float32"  # Avoid NaNs on MPS
+
         (
             self._model,
             self._processor,
             self._token_pooler,
         ) = self._load_model(
             self.model_name,
-            self.dtype,
-            self.device,
-            self.use_token_pooling,
+            dtype,
+            device,
+            self.pooling_strategy,
+            self.pooling_func,
             self.quantization_config,
         )
 
@@ -74,16 +121,26 @@ class ColPaliEmbeddings(EmbeddingFunction):
         model_name: str,
         dtype: str,
         device: str,
-        use_token_pooling: bool,
+        pooling_strategy: Optional[str],
+        pooling_func: Optional[Callable],
         quantization_config: Optional[Any],
     ):
        """
         Initialize and cache the ColPali model, processor, and token pooler.
         """
+        if device.startswith("mps"):
+            # Some torch ops used by late-interaction models can produce NaNs on MPS
+            warning(
+                "MPS device detected. Some operations may result in NaNs. "
+                "If you encounter issues, consider using 'cpu' or 'cuda' devices."
+            )
         torch = attempt_import_or_raise("torch", "torch")
         transformers = attempt_import_or_raise("transformers", "transformers")
         colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
-        from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
+        from colpali_engine.compression.token_pooling import (
+            HierarchicalTokenPooler,
+            LambdaTokenPooler,
+        )
 
         if quantization_config is not None:
             if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
@@ -98,21 +155,45 @@ class ColPaliEmbeddings(EmbeddingFunction):
         else:
             torch_dtype = torch.float32
 
-        model = colpali_engine.models.ColQwen2_5.from_pretrained(
+        model_class, processor_class = None, None
+        model_name_lower = model_name.lower()
+        if "colqwen2.5" in model_name_lower:
+            model_class = colpali_engine.models.ColQwen2_5
+            processor_class = colpali_engine.models.ColQwen2_5_Processor
+        elif "colsmol" in model_name_lower or "colidefics3" in model_name_lower:
+            model_class = colpali_engine.models.ColIdefics3
+            processor_class = colpali_engine.models.ColIdefics3Processor
+        elif "colqwen" in model_name_lower:
+            model_class = colpali_engine.models.ColQwen2
+            processor_class = colpali_engine.models.ColQwen2Processor
+        elif "colpali" in model_name_lower:
+            model_class = colpali_engine.models.ColPali
+            processor_class = colpali_engine.models.ColPaliProcessor
+
+        if model_class is None:
+            raise ValueError(f"Unsupported model: {model_name}")
+
+        model = model_class.from_pretrained(
             model_name,
             torch_dtype=torch_dtype,
-            device_map=device,
             quantization_config=quantization_config
             if quantization_config is not None
             else None,
             attn_implementation="flash_attention_2"
             if is_flash_attn_2_available()
             else None,
+            low_cpu_mem_usage=True,
         ).eval()
-        processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
-            model_name
-        )
-        token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
+        model = model.to(device)
+        model = model.to(torch_dtype)  # Force cast after moving to device
+        processor = processor_class.from_pretrained(model_name)
+
+        token_pooler = None
+        if pooling_strategy == "hierarchical":
+            token_pooler = HierarchicalTokenPooler()
+        elif pooling_strategy == "lambda":
+            token_pooler = LambdaTokenPooler(pool_func=pooling_func)
+
         return model, processor, token_pooler
 
     def ndims(self):
@@ -128,7 +209,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
 
             with torch.no_grad():
                 query_embeddings = self._model(**batch_queries)
-            if self.use_token_pooling and self._token_pooler is not None:
+            if self.pooling_strategy and self._token_pooler is not None:
                 query_embeddings = self._token_pooler.pool_embeddings(
                     query_embeddings,
                     pool_factor=self.pool_factor,
@@ -145,13 +226,20 @@ class ColPaliEmbeddings(EmbeddingFunction):
         Use token pooling if enabled.
         """
         torch = attempt_import_or_raise("torch", "torch")
-        if self.use_token_pooling and self._token_pooler is not None:
-            embeddings = self._token_pooler.pool_embeddings(
-                embeddings,
-                pool_factor=self.pool_factor,
-                padding=True,
-                padding_side=self._processor.tokenizer.padding_side,
-            )
+        if self.pooling_strategy and self._token_pooler is not None:
+            if self.pooling_strategy == "hierarchical":
+                embeddings = self._token_pooler.pool_embeddings(
+                    embeddings,
+                    pool_factor=self.pool_factor,
+                    padding=True,
+                    padding_side=self._processor.tokenizer.padding_side,
+                )
+            elif self.pooling_strategy == "lambda":
+                embeddings = self._token_pooler.pool_embeddings(
+                    embeddings,
+                    padding=True,
+                    padding_side=self._processor.tokenizer.padding_side,
+                )
 
         if isinstance(embeddings, torch.Tensor):
             tensors = embeddings.detach().cpu()
@@ -179,6 +267,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
             )
             with torch.no_grad():
                 query_embeddings = self._model(**batch_queries)
+                query_embeddings = torch.nan_to_num(query_embeddings)
             all_embeddings.extend(self._process_embeddings(query_embeddings))
 
         return all_embeddings
@@ -225,6 +314,7 @@ class ColPaliEmbeddings(EmbeddingFunction):
             )
             with torch.no_grad():
                 image_embeddings = self._model(**batch_images)
+                image_embeddings = torch.nan_to_num(image_embeddings)
             all_embeddings.extend(self._process_embeddings(image_embeddings))
 
         return all_embeddings
diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py
index 50bf76e2..b461c003 100644
--- a/python/python/tests/test_embeddings_slow.py
+++ b/python/python/tests/test_embeddings_slow.py
@@ -656,6 +656,106 @@ def test_colpali(tmp_path):
     )
 
 
+@pytest.mark.slow
+@pytest.mark.skipif(
+    importlib.util.find_spec("colpali_engine") is None,
+    reason="colpali_engine not installed",
+)
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "vidore/colSmol-256M",
+        "vidore/colqwen2.5-v0.2",
+        "vidore/colpali-v1.3",
+        "vidore/colqwen2-v1.0",
+    ],
+)
+def test_colpali_models(tmp_path, model_name):
+    import requests
+    from lancedb.pydantic import LanceModel
+
+    db = lancedb.connect(tmp_path)
+    registry = get_registry()
+    func = registry.get("colpali").create(model_name=model_name)
+
+    class MediaItems(LanceModel):
+        text: str
+        image_uri: str = func.SourceField()
+        image_bytes: bytes = func.SourceField()
+        image_vectors: MultiVector(func.ndims()) = func.VectorField()
+
+    table = db.create_table(f"media_{model_name.replace('/', '_')}", schema=MediaItems)
+
+    texts = [
+        "a cute cat playing with yarn",
+    ]
+
+    uris = [
+        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+    ]
+
+    image_bytes = [requests.get(uri).content for uri in uris]
+
+    table.add(
+        pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
+    )
+
+    image_results = (
+        table.search("fluffy companion", vector_column_name="image_vectors")
+        .limit(1)
+        .to_pydantic(MediaItems)[0]
+    )
+    assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
+
+    first_row = table.to_arrow().to_pylist()[0]
+    assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
+    assert len(first_row["image_vectors"][0]) == func.ndims(), (
+        "Vector dimension mismatch"
+    )
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    importlib.util.find_spec("colpali_engine") is None,
+    reason="colpali_engine not installed",
+)
+def test_colpali_pooling(tmp_path):
+    registry = get_registry()
+    model_name = "vidore/colSmol-256M"
+    test_sentence = "a test sentence for pooling"
+
+    # 1. Get embeddings with no pooling
+    func_no_pool = registry.get("colpali").create(
+        model_name=model_name, pooling_strategy=None
+    )
+    unpooled_embeddings = func_no_pool.generate_text_embeddings([test_sentence])[0]
+    original_length = len(unpooled_embeddings)
+    assert original_length > 1
+
+    # 2. Test hierarchical pooling
+    func_hierarchical = registry.get("colpali").create(
+        model_name=model_name, pooling_strategy="hierarchical", pool_factor=2
+    )
+    hierarchical_embeddings = func_hierarchical.generate_text_embeddings(
+        [test_sentence]
+    )[0]
+    expected_hierarchical_length = (original_length + 1) // 2
+    assert len(hierarchical_embeddings) == expected_hierarchical_length
+
+    # 3. Test lambda pooling
+    def simple_pool_func(tensor):
+        return tensor[::2]
+
+    func_lambda = registry.get("colpali").create(
+        model_name=model_name,
+        pooling_strategy="lambda",
+        pooling_func=simple_pool_func,
+    )
+    lambda_embeddings = func_lambda.generate_text_embeddings([test_sentence])[0]
+    expected_lambda_length = (original_length + 1) // 2
+    assert len(lambda_embeddings) == expected_lambda_length
+
+
 @pytest.mark.slow
 def test_siglip(tmp_path, test_images, query_image_bytes):
     from PIL import Image
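
Usage sketch (illustrative note, not part of the patch): the snippet below exercises the pooling options introduced above through the registry API used in the new tests. The `my_mean_pool` helper and the choice of "vidore/colSmol-256M" are assumptions for illustration only, and the callable's contract (a 2-D tensor of token embeddings in, a smaller 2-D tensor out) is inferred from the `simple_pool_func` in `test_colpali_pooling`.

# Illustrative only -- mirrors the API added in this diff; names marked below are assumptions.
from lancedb.embeddings import get_registry

registry = get_registry()

# Default behaviour: hierarchical token pooling, roughly halving the token count.
func_hier = registry.get("colpali").create(
    model_name="vidore/colSmol-256M",  # assumed model choice
    pooling_strategy="hierarchical",
    pool_factor=2,
)

def my_mean_pool(tensor):
    # Assumed illustrative pooler: mean-pool adjacent pairs of token embeddings,
    # dropping a trailing odd token. Input/output are 2-D (num_tokens, dim) tensors.
    length = tensor.shape[0] - tensor.shape[0] % 2
    return tensor[:length].reshape(-1, 2, tensor.shape[-1]).mean(dim=1)

# Custom pooling: supply the callable via pooling_func.
func_lambda = registry.get("colpali").create(
    model_name="vidore/colSmol-256M",
    pooling_strategy="lambda",
    pooling_func=my_mean_pool,
)

# Disable pooling entirely to keep one vector per token.
func_raw = registry.get("colpali").create(
    model_name="vidore/colSmol-256M", pooling_strategy=None
)

As in the diff, `pool_factor` only affects the hierarchical strategy; the lambda strategy delegates the sequence-length reduction entirely to the provided callable.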