lancedb/python/python/lancedb/embeddings/siglip.py
Poornachandra.A.N 7d0127b376 feat(embeddings): add siglip embedding support to lancedb (#2499)
###  Summary

This PR adds **SigLIP** (Sigmoid Loss for Language-Image Pre-training) as a new
embedding model in the LanceDB embedding registry. SigLIP improves
image-text alignment by training with a pairwise sigmoid-based contrastive loss
instead of the usual softmax objective, and offers robust zero-shot generalization.

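For context (not part of the code change), the sigmoid loss scores every image-text pair in a batch independently instead of normalizing over the batch with a softmax; as I understand it from the SigLIP paper, for $N$ pairs with normalized embeddings $\mathbf{x}_i$ (image) and $\mathbf{y}_j$ (text), learnable temperature $t$ and bias $b$:

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{N}\log\frac{1}{1 + e^{\,z_{ij}\,(-t\,\mathbf{x}_i \cdot \mathbf{y}_j + b)}},
\qquad z_{ij} = \begin{cases} +1 & i = j \\ -1 & i \neq j \end{cases}
$$
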
Fixes #2498 

### What’s Implemented

#### 1. `SigLipEmbeddings` Class

* Added `SigLIP` support under `python/lancedb/embeddings/siglip.py`
* Implements:

  * `compute_source_embeddings` (batched image embedding)
  * `compute_query_embeddings` (text or PIL image queries)
  * Normalization logic
  * Batch-wise progress logging for image embedding

#### 2. Registry Integration

* Registered `SigLIP` in `embeddings/__init__.py`
* `SigLIP` is now available from the embedding registry, e.g.
`get_registry().get("siglip").create()` (see the sketch below)

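A minimal usage sketch, assuming `transformers`, `torch`, and `pillow` are installed; the database path, table name, and image URI below are placeholders:

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# Create the SigLIP embedding function from the registry.
siglip = get_registry().get("siglip").create()


class Images(LanceModel):
    # Image URIs are the source column; vectors are computed automatically on ingest.
    image_uri: str = siglip.SourceField()
    vector: Vector(siglip.ndims()) = siglip.VectorField()


db = lancedb.connect("/tmp/siglip-demo")  # placeholder path
table = db.create_table("images", schema=Images, mode="overwrite")
table.add([{"image_uri": "https://example.com/cat.jpg"}])  # placeholder image URI
```
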
#### 3. Evaluation Benchmark Support

* Added SigLIP to `test_embeddings_slow.py` for side-by-side
benchmarking with OpenCLIP and ImageBind


###  New Test Methods

####  `test_siglip`

* End-to-end test to verify embeddings table creation and vector shape
for SigLIP (a rough sketch of what it checks follows the screenshot below)
![test_siglip output](https://github.com/user-attachments/assets/e5582ee1-80a3-43d7-a7a1-26ceecce9f4d)

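Roughly what that check looks like, as a sketch rather than the exact test code (names and the image URI are illustrative):

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector


def test_siglip(tmp_path):
    siglip = get_registry().get("siglip").create()

    class Images(LanceModel):
        image_uri: str = siglip.SourceField()
        vector: Vector(siglip.ndims()) = siglip.VectorField()

    table = lancedb.connect(tmp_path).create_table("images", schema=Images)
    table.add([{"image_uri": "https://example.com/cat.jpg"}])  # placeholder image URI

    # The default google/siglip-base-patch16-224 checkpoint produces 768-D vectors.
    vec = table.to_pandas()["vector"][0]
    assert len(vec) == siglip.ndims() == 768
```
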

####  `test_siglip_vs_openclip_vs_imagebind_benchmark_full`

* Benchmarks (a sketch of the Recall@k computation follows the screenshot below):

  * **Recall\@1 / 5 / 10**
  * **mAP (Mean Average Precision)**
  * **Embedding & Search Latency**
  * Dimensionality reporting
![benchmark output](https://github.com/user-attachments/assets/455bf30f-62b7-4684-a3f3-ad52e2a1ffe5)

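For readers unfamiliar with the metrics, here is a minimal sketch of how Recall@k can be computed for this kind of retrieval benchmark (illustrative only, not the code in `test_embeddings_slow.py`):

```python
import numpy as np


def recall_at_k(query_vecs: np.ndarray, item_vecs: np.ndarray, k: int) -> float:
    """Fraction of queries whose ground-truth item (same row index) ranks in the top-k.

    Assumes both matrices are L2-normalized, so dot product equals cosine similarity.
    """
    sims = query_vecs @ item_vecs.T             # (n_queries, n_items) similarity matrix
    topk = np.argsort(-sims, axis=1)[:, :k]     # indices of the k most similar items
    hits = (topk == np.arange(len(query_vecs))[:, None]).any(axis=1)
    return float(hits.mean())
```
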

###  Notes

* SigLIP (with the default `google/siglip-base-patch16-224`) outputs 768-D embeddings (vs 512-D for OpenCLIP)
* Benchmark shows competitive performance despite higher dimensionality
* I'm still new to contributing to open-source and learning as I go.
Please feel free to suggest any improvements — I'm happy to make
changes!

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import concurrent.futures
import io
import os
import urllib.parse as urlparse
from typing import TYPE_CHECKING, List, Union

import numpy as np
import pyarrow as pa
from pydantic import PrivateAttr
from tqdm import tqdm

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve

if TYPE_CHECKING:
    import PIL
    import torch


@register("siglip")
class SigLipEmbeddings(EmbeddingFunction):
    """SigLIP embedding function backed by Hugging Face ``transformers``.

    Embeds images and text into a shared vector space for multi-modal search.
    """

    model_name: str = "google/siglip-base-patch16-224"
    device: str = "cpu"
    batch_size: int = 64
    normalize: bool = True

    _model = PrivateAttr()
    _processor = PrivateAttr()
    _tokenizer = PrivateAttr()
    _torch = PrivateAttr()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        transformers = attempt_import_or_raise("transformers")
        self._torch = attempt_import_or_raise("torch")
        self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
        self._model = transformers.SiglipModel.from_pretrained(self.model_name)
        self._model.to(self.device)
        self._model.eval()
        self._ndims = None

    def ndims(self):
        # Lazily infer the embedding dimension from a dummy text embedding.
        if self._ndims is None:
            self._ndims = self.generate_text_embeddings("foo").shape[0]
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        if isinstance(query, str):
            return [self.generate_text_embeddings(query)]
        else:
            PIL = attempt_import_or_raise("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError("SigLIP supports str or PIL Image as query")

    def generate_text_embeddings(self, text: str) -> np.ndarray:
        torch = self._torch
        text_inputs = self._processor(
            text=text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64,
        ).to(self.device)
        with torch.no_grad():
            text_features = self._model.get_text_features(**text_inputs)
            if self.normalize:
                text_features = text_features / text_features.norm(
                    dim=-1, keepdim=True
                )
            return text_features.cpu().detach().numpy().squeeze()

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """Coerce single items and Arrow arrays into a plain list of images."""
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[np.ndarray]:
        images = self.sanitize_input(images)
        embeddings = []
        for i in range(0, len(images), self.batch_size):
            j = min(i + self.batch_size, len(images))
            batch = images[i:j]
            embeddings.extend(self._parallel_get(batch))
        return embeddings

    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
        # Embed one batch of images concurrently, reporting progress via tqdm.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
                for image in images
            ]
            return [f.result() for f in tqdm(futures, desc="SigLIP Embedding")]

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        image = self._to_pil(image)
        image = self._processor(images=image, return_tensors="pt")["pixel_values"]
        return self._encode_and_normalize_image(image)

    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor") -> np.ndarray:
        torch = self._torch
        with torch.no_grad():
            image_features = self._model.get_image_features(
                image_tensor.to(self.device)
            )
            if self.normalize:
                image_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
            return image_features.cpu().detach().numpy().squeeze()

    def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
        """Load a PIL image from a PIL image, raw bytes, a local path, or a URL."""
        PIL = attempt_import_or_raise("PIL", "pillow")
        if isinstance(image, PIL.Image.Image):
            return image.convert("RGB") if image.mode != "RGB" else image
        elif isinstance(image, bytes):
            return PIL.Image.open(io.BytesIO(image)).convert("RGB")
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            if parsed.scheme == "file":
                return PIL.Image.open(parsed.path).convert("RGB")
            elif parsed.scheme == "":
                path = image if os.name == "nt" else parsed.path
                return PIL.Image.open(path).convert("RGB")
            elif parsed.scheme.startswith("http"):
                image_bytes = url_retrieve(image)
                return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
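
Finally, a quick sketch of querying a table that uses this function. Since `compute_query_embeddings` accepts either a string or a `PIL.Image.Image`, both text and image queries should work; the path, table name, and query file below are placeholders, continuing the earlier sketch:

```python
import lancedb
import PIL.Image

table = lancedb.connect("/tmp/siglip-demo").open_table("images")  # placeholder path/table

# Text query: the string is embedded via generate_text_embeddings.
hits = table.search("a photo of a dog").limit(5).to_pandas()

# Image query: a PIL image is embedded via generate_image_embedding.
hits = table.search(PIL.Image.open("query.jpg")).limit(5).to_pandas()  # placeholder file
```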