From 7d0127b376871b11076f0bfcf471cda0fe665605 Mon Sep 17 00:00:00 2001
From: "Poornachandra.A.N" <124677280+Heisenberg208@users.noreply.github.com>
Date: Tue, 5 Aug 2025 00:12:39 +0530
Subject: [PATCH] feat(embeddings): add siglip embedding support to lancedb (#2499)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Summary

This PR adds **SigLIP** (Sigmoid Loss for Language-Image Pre-training) as a new embedding model in the LanceDB embedding registry. SigLIP improves image-text alignment by replacing CLIP's softmax-based contrastive loss with a pairwise sigmoid loss, and offers robust zero-shot generalization.

Fixes #2498

### What’s Implemented

#### 1. `SigLipEmbeddings` Class

* Added SigLIP support under `python/lancedb/embeddings/siglip.py`
* Implements:
  * `compute_source_embeddings` and `compute_query_embeddings` for image and text inputs
  * `_parallel_get` for threaded, batch-wise image embedding
  * Optional L2 normalization of text and image features
  * Batch-wise progress logging (via `tqdm`) for image embedding

#### 2. Registry Integration

* Exported `SigLipEmbeddings` from `embeddings/__init__.py`
* Registered under the `"siglip"` alias, so it is usable via `get_registry().get("siglip").create()` (see the usage sketch below)

#### 3. Evaluation Benchmark Support

* Added SigLIP to `test_embeddings_slow.py` for side-by-side benchmarking with OpenCLIP and ImageBind

### New Test Methods

#### `test_siglip`

* End-to-end test that verifies table creation, text-to-image and image-to-image search, and vector shape for SigLIP

![WhatsApp Image 2025-07-10 at 18 00 27_a3368163](https://github.com/user-attachments/assets/e5582ee1-80a3-43d7-a7a1-26ceecce9f4d)

#### `test_siglip_vs_openclip_vs_imagebind_benchmark_full`

* Benchmarks:
  * **Recall@1 / @5 / @10**
  * **mAP (Mean Average Precision)**
  * **Embedding & search latency**
  * Dimensionality reporting

![WhatsApp Image 2025-07-10 at 18 12 13_22c67a84](https://github.com/user-attachments/assets/455bf30f-62b7-4684-a3f3-ad52e2a1ffe5)

### Notes

* The default model (`google/siglip-base-patch16-224`) outputs 768-dimensional embeddings, vs. 512 for the default OpenCLIP model
* The benchmark shows competitive retrieval quality despite the higher dimensionality
* I'm still new to contributing to open source and learning as I go. Please feel free to suggest any improvements; I'm happy to make changes!
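### Usage Sketch

For reviewers, here is a minimal usage sketch of the new embedding function, mirroring what `test_siglip` exercises. The database path, table name, and row contents are illustrative (the image URI is one of the fixture URIs from the test file); this is not part of the patch itself.

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# Look up the registered SigLIP function; the default model
# (google/siglip-base-patch16-224) is downloaded on first use.
func = get_registry().get("siglip").create()


class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()  # embedded automatically on add()
    vector: Vector(func.ndims()) = func.VectorField()  # 768 dims for the base model


db = lancedb.connect("./siglip_demo")  # illustrative local path
table = db.create_table("images", schema=Images)
table.add(
    [
        {
            "label": "dog",
            "image_uri": "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        }
    ]
)

# Text queries go through compute_query_embeddings; with a single
# vector column the column name can be omitted.
print(table.search("man's best friend").limit(1).to_pandas())
```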
---
 python/.gitignore                            |   3 +-
 python/pyproject.toml                        |   4 +-
 python/python/lancedb/__init__.py            |   2 +-
 python/python/lancedb/embeddings/__init__.py |   1 +
 python/python/lancedb/embeddings/siglip.py   | 148 +++++++++++++++++++
 python/python/tests/test_embeddings_slow.py  | 121 ++++++++++++---
 6 files changed, 257 insertions(+), 22 deletions(-)
 create mode 100644 python/python/lancedb/embeddings/siglip.py

diff --git a/python/.gitignore b/python/.gitignore
index 10d800aa..cd6b09fa 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -1,2 +1,3 @@
 # Test data created by some example tests
-data/
\ No newline at end of file
+data/
+_lancedb.pyd
diff --git a/python/pyproject.toml b/python/pyproject.toml
index b8405bc6..f9f4c55a 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -68,8 +68,9 @@ dev = [
     "pyright",
     'typing-extensions>=4.0.0; python_version < "3.11"',
 ]
-docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
+docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings-python"]
 clip = ["torch", "pillow", "open-clip-torch"]
+siglip = ["torch", "pillow", "transformers>=4.41.0", "sentencepiece"]
 embeddings = [
     "requests>=2.31.0",
     "openai>=1.6.1",
@@ -87,6 +88,7 @@ embeddings = [
     "botocore>=1.31.57",
     'ibm-watsonx-ai>=1.1.2; python_version >= "3.10"',
     "ollama>=0.3.0",
+    "sentencepiece"
 ]

 azure = ["adlfs>=2024.2.0"]
diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py
index 3227a9af..e4112577 100644
--- a/python/python/lancedb/__init__.py
+++ b/python/python/lancedb/__init__.py
@@ -241,4 +241,4 @@ def __warn_on_fork():


 if hasattr(os, "register_at_fork"):
-    os.register_at_fork(before=__warn_on_fork)
+    os.register_at_fork(before=__warn_on_fork)  # type: ignore[attr-defined]
diff --git a/python/python/lancedb/embeddings/__init__.py b/python/python/lancedb/embeddings/__init__.py
index 3cf320f8..f70aa57b 100644
--- a/python/python/lancedb/embeddings/__init__.py
+++ b/python/python/lancedb/embeddings/__init__.py
@@ -20,3 +20,4 @@ from .jinaai import JinaEmbeddings
 from .watsonx import WatsonxEmbeddings
 from .voyageai import VoyageAIEmbeddingFunction
 from .colpali import ColPaliEmbeddings
+from .siglip import SigLipEmbeddings
diff --git a/python/python/lancedb/embeddings/siglip.py b/python/python/lancedb/embeddings/siglip.py
new file mode 100644
index 00000000..7e9c6adc
--- /dev/null
+++ b/python/python/lancedb/embeddings/siglip.py
@@ -0,0 +1,148 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import concurrent.futures
+import io
+import os
+from typing import TYPE_CHECKING, List, Union
+import urllib.parse as urlparse
+
+import numpy as np
+import pyarrow as pa
+from tqdm import tqdm
+from pydantic import PrivateAttr
+
+from ..util import attempt_import_or_raise
+from .base import EmbeddingFunction
+from .registry import register
+from .utils import IMAGES, url_retrieve
+
+if TYPE_CHECKING:
+    import PIL
+    import torch
+
+
+@register("siglip")
+class SigLipEmbeddings(EmbeddingFunction):
+    model_name: str = "google/siglip-base-patch16-224"
+    device: str = "cpu"
+    batch_size: int = 64
+    normalize: bool = True
+
+    _model = PrivateAttr()
+    _processor = PrivateAttr()
+    _tokenizer = PrivateAttr()
+    _torch = PrivateAttr()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        transformers = attempt_import_or_raise("transformers")
+        self._torch = attempt_import_or_raise("torch")
+
+        self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
+        self._model = transformers.SiglipModel.from_pretrained(self.model_name)
+        self._model.to(self.device)
+        self._model.eval()
+        self._ndims = None
+
+    def ndims(self):
+        if self._ndims is None:
+            self._ndims = self.generate_text_embeddings("foo").shape[0]
+        return self._ndims
+
+    def compute_query_embeddings(
+        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
+    ) -> List[np.ndarray]:
+        if isinstance(query, str):
+            return [self.generate_text_embeddings(query)]
+        else:
+            PIL = attempt_import_or_raise("PIL", "pillow")
+            if isinstance(query, PIL.Image.Image):
+                return [self.generate_image_embedding(query)]
+            else:
+                raise TypeError("SigLIP supports str or PIL Image as query")
+
+    def generate_text_embeddings(self, text: str) -> np.ndarray:
+        torch = self._torch
+        text_inputs = self._processor(
+            text=text,
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=64,
+        ).to(self.device)
+
+        with torch.no_grad():
+            text_features = self._model.get_text_features(**text_inputs)
+            if self.normalize:
+                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+            return text_features.cpu().detach().numpy().squeeze()
+
+    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
+        if isinstance(images, (str, bytes)):
+            images = [images]
+        elif isinstance(images, pa.Array):
+            images = images.to_pylist()
+        elif isinstance(images, pa.ChunkedArray):
+            images = images.combine_chunks().to_pylist()
+        return images
+
+    def compute_source_embeddings(
+        self, images: IMAGES, *args, **kwargs
+    ) -> List[np.ndarray]:
+        images = self.sanitize_input(images)
+        embeddings = []
+
+        for i in range(0, len(images), self.batch_size):
+            j = min(i + self.batch_size, len(images))
+            batch = images[i:j]
+            embeddings.extend(self._parallel_get(batch))
+        return embeddings
+
+    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(self.generate_image_embedding, image)
+                for image in images
+            ]
+            return [f.result() for f in tqdm(futures, desc="SigLIP Embedding")]
+
+    def generate_image_embedding(
+        self, image: Union[str, bytes, "PIL.Image.Image"]
+    ) -> np.ndarray:
+        image = self._to_pil(image)
+        image = self._processor(images=image, return_tensors="pt")["pixel_values"]
+        return self._encode_and_normalize_image(image)
+
+    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor") -> np.ndarray:
+        torch = self._torch
+        with torch.no_grad():
+            image_features = self._model.get_image_features(
+                image_tensor.to(self.device)
+            )
+            if self.normalize:
+                image_features = image_features / image_features.norm(
+                    dim=-1, keepdim=True
+                )
+            return image_features.cpu().detach().numpy().squeeze()
+
+    def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
+        PIL = attempt_import_or_raise("PIL", "pillow")
+        if isinstance(image, PIL.Image.Image):
+            return image.convert("RGB") if image.mode != "RGB" else image
+        elif isinstance(image, bytes):
+            return PIL.Image.open(io.BytesIO(image)).convert("RGB")
+        elif isinstance(image, str):
+            parsed = urlparse.urlparse(image)
+            if parsed.scheme == "file":
+                return PIL.Image.open(parsed.path).convert("RGB")
+            elif parsed.scheme == "":
+                path = image if os.name == "nt" else parsed.path
+                return PIL.Image.open(path).convert("RGB")
+            elif parsed.scheme.startswith("http"):
+                image_bytes = url_retrieve(image)
+                return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
+            else:
+                raise NotImplementedError("Only local and http(s) urls are supported")
+        else:
+            raise ValueError(f"Unsupported image type: {type(image)}")
diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py
index d0d71577..50bf76e2 100644
--- a/python/python/tests/test_embeddings_slow.py
+++ b/python/python/tests/test_embeddings_slow.py
@@ -4,7 +4,6 @@
 import importlib
 import io
 import os
-
 import lancedb
 import numpy as np
 import pandas as pd
@@ -12,7 +11,6 @@ import pyarrow as pa
 import pytest
 from lancedb.embeddings import get_registry
 from lancedb.pydantic import LanceModel, Vector, MultiVector
-import requests

 # These are integration tests for embedding functions.
 # They are slow because they require downloading models
@@ -98,9 +96,34 @@ def test_basic_text_embeddings(alias, tmp_path):
     assert not np.allclose(actual.vector, actual.vector2)


-@pytest.mark.slow
-def test_openclip(tmp_path):
+@pytest.fixture(scope="module")
+def test_images():
     import requests
+
+    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
+    uris = [
+        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
+        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
+        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
+    ]
+    image_bytes = [requests.get(uri).content for uri in uris]
+    return labels, uris, image_bytes
+
+
+@pytest.fixture(scope="module")
+def query_image_bytes():
+    import requests
+
+    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
+    image_bytes = requests.get(query_image_uri).content
+    return image_bytes
+
+
+@pytest.mark.slow
+def test_openclip(tmp_path, test_images, query_image_bytes):
     from PIL import Image

     db = lancedb.connect(tmp_path)
@@ -114,20 +137,12 @@
         vector: Vector(func.ndims()) = func.VectorField()
         vec_from_bytes: Vector(func.ndims()) = func.VectorField()

+    labels, uris, image_bytes_list = test_images
     table = db.create_table("images", schema=Images)
-    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
-    uris = [
-        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
-        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
-        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
-        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
-        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
-        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
-    ]
-    # get each uri as bytes
-    image_bytes = [requests.get(uri).content for uri in uris]
     table.add(
-        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
+        pd.DataFrame(
+            {"label": labels, "image_uri": uris, "image_bytes": image_bytes_list}
+        )
     )

     # text search
@@ -146,9 +161,7 @@
     assert np.allclose(actual.vector, frombytes.vector)

     # image search
-    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
-    image_bytes = requests.get(query_image_uri).content
-    query_image = Image.open(io.BytesIO(image_bytes))
+    query_image = Image.open(io.BytesIO(query_image_bytes))
     actual = (
         table.search(query_image, vector_column_name="vector")
         .limit(1)
@@ -524,6 +537,8 @@ def test_voyageai_embedding_function():
     os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
 )
 def test_voyageai_multimodal_embedding_function():
+    import requests
+
     voyageai = (
         get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
     )
@@ -639,3 +654,71 @@ def test_colpali(tmp_path):
     assert len(first_row["image_vectors"][0]) == func.ndims(), (
         "Vector dimension mismatch"
     )
+
+
+@pytest.mark.slow
+def test_siglip(tmp_path, test_images, query_image_bytes):
+    from PIL import Image
+
+    labels, uris, image_bytes = test_images
+
+    db = lancedb.connect(tmp_path)
+    registry = get_registry()
+    func = registry.get("siglip").create(max_retries=0)
+
+    class Images(LanceModel):
+        label: str
+        image_uri: str = func.SourceField()
+        image_bytes: bytes = func.SourceField()
+        vector: Vector(func.ndims()) = func.VectorField()
+        vec_from_bytes: Vector(func.ndims()) = func.VectorField()
+
+    table = db.create_table("images", schema=Images)
+
+    table.add(
+        pd.DataFrame(
+            {
+                "label": labels,
+                "image_uri": uris,
+                "image_bytes": image_bytes,
+            }
+        )
+    )
+
+    # Text search
+    actual = (
+        table.search("man's best friend", vector_column_name="vector")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == "dog"
+
+    frombytes = (
+        table.search("man's best friend", vector_column_name="vec_from_bytes")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == frombytes.label
+    assert np.allclose(actual.vector, frombytes.vector)
+
+    # Image search
+    query_image = Image.open(io.BytesIO(query_image_bytes))
+    actual = (
+        table.search(query_image, vector_column_name="vector")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == "dog"
+
+    other = (
+        table.search(query_image, vector_column_name="vec_from_bytes")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == other.label
+
+    arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
+    assert np.allclose(
+        arrow_table["vector"].combine_chunks().values.to_numpy(),
+        arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
+    )