diff --git a/python/pyproject.toml b/python/pyproject.toml index 28d378fc..1df029b1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -77,6 +77,7 @@ embeddings = [ "pillow", "open-clip-torch", "cohere", + "colpali-engine>=0.3.10", "huggingface_hub", "InstructorEmbedding", "google.generativeai", diff --git a/python/python/lancedb/embeddings/__init__.py b/python/python/lancedb/embeddings/__init__.py index c4854fd1..8164a6b3 100644 --- a/python/python/lancedb/embeddings/__init__.py +++ b/python/python/lancedb/embeddings/__init__.py @@ -19,3 +19,4 @@ from .imagebind import ImageBindEmbeddings from .jinaai import JinaEmbeddings from .watsonx import WatsonxEmbeddings from .voyageai import VoyageAIEmbeddingFunction +from .colpali import ColPaliEmbeddings diff --git a/python/python/lancedb/embeddings/colpali.py b/python/python/lancedb/embeddings/colpali.py new file mode 100644 index 00000000..150eaa52 --- /dev/null +++ b/python/python/lancedb/embeddings/colpali.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + + +from functools import lru_cache +from typing import List, Union, Optional, Any +import numpy as np +import io + +from ..util import attempt_import_or_raise +from .base import EmbeddingFunction +from .registry import register +from .utils import TEXT, IMAGES, is_flash_attn_2_available + + +@register("colpali") +class ColPaliEmbeddings(EmbeddingFunction): + """ + An embedding function that uses the ColPali engine for + multimodal multi-vector embeddings. + + This embedding function supports ColQwen2.5 models, producing multivector outputs + for both text and image inputs. The output embeddings are lists of vectors, each + vector being 128-dimensional by default, represented as List[List[float]]. + + Parameters + ---------- + model_name : str + The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0") + device : str + The device for inference (default "cuda:0"). + dtype : str + Data type for model weights (default "bfloat16"). + use_token_pooling : bool + Whether to use token pooling to reduce embedding size (default True). + pool_factor : int + Factor to reduce sequence length if token pooling is enabled (default 2). + quantization_config : Optional[BitsAndBytesConfig] + Quantization configuration for the model. (default None, bitsandbytes needed) + batch_size : int + Batch size for processing inputs (default 2). + """ + + model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0" + device: str = "auto" + dtype: str = "bfloat16" + use_token_pooling: bool = True + pool_factor: int = 2 + quantization_config: Optional[Any] = None + batch_size: int = 2 + + _model = None + _processor = None + _token_pooler = None + _vector_dim = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + ( + self._model, + self._processor, + self._token_pooler, + ) = self._load_model( + self.model_name, + self.dtype, + self.device, + self.use_token_pooling, + self.quantization_config, + ) + + @staticmethod + @lru_cache(maxsize=1) + def _load_model( + model_name: str, + dtype: str, + device: str, + use_token_pooling: bool, + quantization_config: Optional[Any], + ): + """ + Initialize and cache the ColPali model, processor, and token pooler. + """ + torch = attempt_import_or_raise("torch", "torch") + transformers = attempt_import_or_raise("transformers", "transformers") + colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine") + from colpali_engine.compression.token_pooling import HierarchicalTokenPooler + + if quantization_config is not None: + if not isinstance(quantization_config, transformers.BitsAndBytesConfig): + raise ValueError("quantization_config must be a BitsAndBytesConfig") + + if dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif dtype == "float16": + torch_dtype = torch.float16 + elif dtype == "float64": + torch_dtype = torch.float64 + else: + torch_dtype = torch.float32 + + model = colpali_engine.models.ColQwen2_5.from_pretrained( + model_name, + torch_dtype=torch_dtype, + device_map=device, + quantization_config=quantization_config + if quantization_config is not None + else None, + attn_implementation="flash_attention_2" + if is_flash_attn_2_available() + else None, + ).eval() + processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained( + model_name + ) + token_pooler = HierarchicalTokenPooler() if use_token_pooling else None + return model, processor, token_pooler + + def ndims(self): + """ + Return the dimension of a vector in the multivector output (e.g., 128). + """ + torch = attempt_import_or_raise("torch", "torch") + if self._vector_dim is None: + dummy_query = "test" + batch_queries = self._processor.process_queries([dummy_query]).to( + self._model.device + ) + with torch.no_grad(): + query_embeddings = self._model(**batch_queries) + + if self.use_token_pooling and self._token_pooler is not None: + query_embeddings = self._token_pooler.pool_embeddings( + query_embeddings, + pool_factor=self.pool_factor, + padding=True, + padding_side=self._processor.tokenizer.padding_side, + ) + + self._vector_dim = query_embeddings[0].shape[-1] + return self._vector_dim + + def _process_embeddings(self, embeddings): + """ + Format model embeddings into List[List[float]]. + Use token pooling if enabled. + """ + torch = attempt_import_or_raise("torch", "torch") + if self.use_token_pooling and self._token_pooler is not None: + embeddings = self._token_pooler.pool_embeddings( + embeddings, + pool_factor=self.pool_factor, + padding=True, + padding_side=self._processor.tokenizer.padding_side, + ) + + if isinstance(embeddings, torch.Tensor): + tensors = embeddings.detach().cpu() + if tensors.dtype == torch.bfloat16: + tensors = tensors.to(torch.float32) + return ( + tensors.numpy() + .astype(np.float64 if self.dtype == "float64" else np.float32) + .tolist() + ) + return [] + + def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]: + """ + Generate embeddings for text input. + """ + torch = attempt_import_or_raise("torch", "torch") + text = self.sanitize_input(text) + all_embeddings = [] + + for i in range(0, len(text), self.batch_size): + batch_text = text[i : i + self.batch_size] + batch_queries = self._processor.process_queries(batch_text).to( + self._model.device + ) + with torch.no_grad(): + query_embeddings = self._model(**batch_queries) + all_embeddings.extend(self._process_embeddings(query_embeddings)) + return all_embeddings + + def _prepare_images(self, images: IMAGES) -> List: + """ + Convert image inputs to PIL Images. + """ + PIL = attempt_import_or_raise("PIL", "pillow") + requests = attempt_import_or_raise("requests", "requests") + images = self.sanitize_input(images) + pil_images = [] + try: + for image in images: + if isinstance(image, str): + if image.startswith(("http://", "https://")): + response = requests.get(image, timeout=10) + response.raise_for_status() + pil_images.append(PIL.Image.open(io.BytesIO(response.content))) + else: + with PIL.Image.open(image) as im: + pil_images.append(im.copy()) + elif isinstance(image, bytes): + pil_images.append(PIL.Image.open(io.BytesIO(image))) + else: + # Assume it's a PIL Image; will raise if invalid + pil_images.append(image) + except Exception as e: + raise ValueError(f"Failed to process image: {e}") + + return pil_images + + def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]: + """ + Generate embeddings for a batch of images. + """ + torch = attempt_import_or_raise("torch", "torch") + pil_images = self._prepare_images(images) + all_embeddings = [] + + for i in range(0, len(pil_images), self.batch_size): + batch_images = pil_images[i : i + self.batch_size] + batch_images = self._processor.process_images(batch_images).to( + self._model.device + ) + with torch.no_grad(): + image_embeddings = self._model(**batch_images) + all_embeddings.extend(self._process_embeddings(image_embeddings)) + return all_embeddings + + def compute_query_embeddings( + self, query: Union[str, IMAGES], *args, **kwargs + ) -> List[List[List[float]]]: + """ + Compute embeddings for a single user query (text only). + """ + if not isinstance(query, str): + raise ValueError( + "Query must be a string, image to image search is not supported" + ) + return self.generate_text_embeddings([query]) + + def compute_source_embeddings( + self, images: IMAGES, *args, **kwargs + ) -> List[List[List[float]]]: + """ + Compute embeddings for a batch of source images. + + Parameters + ---------- + images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray] + Batch of images (paths, URLs, bytes, or PIL Images). + """ + images = self.sanitize_input(images) + return self.generate_image_embeddings(images) diff --git a/python/python/lancedb/embeddings/utils.py b/python/python/lancedb/embeddings/utils.py index 6a4c577c..a9e4bfb8 100644 --- a/python/python/lancedb/embeddings/utils.py +++ b/python/python/lancedb/embeddings/utils.py @@ -18,6 +18,7 @@ import numpy as np import pyarrow as pa from ..dependencies import pandas as pd +from ..util import attempt_import_or_raise # ruff: noqa: PERF203 @@ -275,3 +276,12 @@ def url_retrieve(url: str): def api_key_not_found_help(provider): logging.error("Could not find API key for %s", provider) raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.") + + +def is_flash_attn_2_available(): + try: + attempt_import_or_raise("flash_attn", "flash_attn") + + return True + except ImportError: + return False diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index c665c9de..f68d0163 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -152,6 +152,104 @@ def Vector( return FixedSizeList +def MultiVector( + dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True +) -> Type: + """Pydantic MultiVector Type for multi-vector embeddings. + + This type represents a list of vectors, each with the same dimension. + Useful for models that produce multiple embeddings per input, like ColPali. + + Parameters + ---------- + dim : int + The dimension of each vector in the multi-vector. + value_type : pyarrow.DataType, optional + The value type of the vectors, by default pa.float32() + nullable : bool, optional + Whether the multi-vector is nullable, by default it is True. + + Examples + -------- + + >>> import pydantic + >>> from lancedb.pydantic import MultiVector + ... + >>> class MyModel(pydantic.BaseModel): + ... id: int + ... text: str + ... embeddings: MultiVector(128) # List of 128-dimensional vectors + >>> schema = pydantic_to_schema(MyModel) + >>> assert schema == pa.schema([ + ... pa.field("id", pa.int64(), False), + ... pa.field("text", pa.utf8(), False), + ... pa.field("embeddings", pa.list_(pa.list_(pa.float32(), 128))) + ... ]) + """ + + class MultiVectorList(list, FixedSizeListMixin): + def __repr__(self): + return f"MultiVector(dim={dim})" + + @staticmethod + def nullable() -> bool: + return nullable + + @staticmethod + def dim() -> int: + return dim + + @staticmethod + def value_arrow_type() -> pa.DataType: + return value_type + + @staticmethod + def is_multi_vector() -> bool: + return True + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler + ) -> CoreSchema: + return core_schema.no_info_after_validator_function( + cls, + core_schema.list_schema( + items_schema=core_schema.list_schema( + min_length=dim, + max_length=dim, + items_schema=core_schema.float_schema(), + ), + ), + ) + + @classmethod + def __get_validators__(cls) -> Generator[Callable, None, None]: + yield cls.validate + + # For pydantic v1 + @classmethod + def validate(cls, v): + if not isinstance(v, (list, range)): + raise TypeError("A list of vectors is needed") + for vec in v: + if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim: + raise TypeError(f"Each vector must be a list of {dim} numbers") + return cls(v) + + if PYDANTIC_VERSION.major < 2: + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]): + field_schema["items"] = { + "type": "array", + "items": {"type": "number"}, + "minItems": dim, + "maxItems": dim, + } + + return MultiVectorList + + def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType: """Convert a field with native Python type to Arrow data type. @@ -206,6 +304,9 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType: fields = _pydantic_model_to_fields(tp) return pa.struct(fields) if issubclass(tp, FixedSizeListMixin): + if getattr(tp, "is_multi_vector", lambda: False)(): + return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim())) + # For regular Vector return pa.list_(tp.value_arrow_type(), tp.dim()) return _py_type_to_arrow_type(tp, field) diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 0eec2b3d..d0d71577 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -11,7 +11,7 @@ import pandas as pd import pyarrow as pa import pytest from lancedb.embeddings import get_registry -from lancedb.pydantic import LanceModel, Vector +from lancedb.pydantic import LanceModel, Vector, MultiVector import requests # These are integration tests for embedding functions. @@ -575,3 +575,67 @@ def test_voyageai_multimodal_embedding_text_function(): tbl.add(df) assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims() + + +@pytest.mark.slow +@pytest.mark.skipif( + importlib.util.find_spec("colpali_engine") is None, + reason="colpali_engine not installed", +) +def test_colpali(tmp_path): + import requests + from lancedb.pydantic import LanceModel + + db = lancedb.connect(tmp_path) + registry = get_registry() + func = registry.get("colpali").create() + + class MediaItems(LanceModel): + text: str + image_uri: str = func.SourceField() + image_bytes: bytes = func.SourceField() + image_vectors: MultiVector(func.ndims()) = ( + func.VectorField() + ) # Multivector image embeddings + + table = db.create_table("media", schema=MediaItems) + + texts = [ + "a cute cat playing with yarn", + "a puppy in a flower field", + "a red sports car on the highway", + "a vintage bicycle leaning against a wall", + "a plate of delicious pasta", + "fresh fruit salad in a bowl", + ] + + uris = [ + "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg", + "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg", + "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg", + "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg", + "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg", + "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg", + ] + + # Get images as bytes + image_bytes = [requests.get(uri).content for uri in uris] + + table.add( + pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes}) + ) + + # Test text-to-image search + image_results = ( + table.search("fluffy companion", vector_column_name="image_vectors") + .limit(1) + .to_pydantic(MediaItems)[0] + ) + assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower() + + # Verify multivector dimensions + first_row = table.to_arrow().to_pylist()[0] + assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors" + assert len(first_row["image_vectors"][0]) == func.ndims(), ( + "Vector dimension mismatch" + ) diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 1648a518..514871cc 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -9,7 +9,13 @@ from typing import List, Optional, Tuple import pyarrow as pa import pydantic import pytest -from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema +from lancedb.pydantic import ( + PYDANTIC_VERSION, + LanceModel, + Vector, + pydantic_to_schema, + MultiVector, +) from pydantic import BaseModel from pydantic import Field @@ -354,3 +360,55 @@ def test_optional_nested_model(): ), ] ) + + +def test_multi_vector(): + class TestModel(pydantic.BaseModel): + vec: MultiVector(8) + + schema = pydantic_to_schema(TestModel) + assert schema == pa.schema( + [pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)] + ) + + with pytest.raises(pydantic.ValidationError): + TestModel(vec=[[1.0] * 7]) + + with pytest.raises(pydantic.ValidationError): + TestModel(vec=[[1.0] * 9]) + + TestModel(vec=[[1.0] * 8]) + TestModel(vec=[[1.0] * 8, [2.0] * 8]) + + TestModel(vec=[]) + + +def test_multi_vector_nullable(): + class NullableModel(pydantic.BaseModel): + vec: MultiVector(16, nullable=False) + + schema = pydantic_to_schema(NullableModel) + assert schema == pa.schema( + [pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)] + ) + + class DefaultModel(pydantic.BaseModel): + vec: MultiVector(16) + + schema = pydantic_to_schema(DefaultModel) + assert schema == pa.schema( + [pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)] + ) + + +def test_multi_vector_in_lance_model(): + class TestModel(LanceModel): + id: int + vectors: MultiVector(16) = Field(default=[[0.0] * 16]) + + schema = pydantic_to_schema(TestModel) + assert schema == TestModel.to_arrow_schema() + assert TestModel.field_names() == ["id", "vectors"] + + t = TestModel(id=1) + assert t.vectors == [[0.0] * 16] diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 7e3aaf81..23d0e35f 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -257,7 +257,9 @@ async def test_distance_range_with_new_rows_async(): } ) table = await conn.create_table("test", data) - table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2)) + await table.create_index( + "vector", config=IvfPq(num_partitions=1, num_sub_vectors=2) + ) q = [0, 0] rs = await table.query().nearest_to(q).to_arrow()