feat(embeddings): add siglip embedding support to lancedb (#2499)

### Summary

This PR adds **SigLIP** (Sigmoid Loss for Language-Image Pre-training) as a new
embedding model in the LanceDB embedding registry. SigLIP improves image-text
alignment by replacing the softmax contrastive loss with a pairwise sigmoid
loss, and offers robust zero-shot generalization.
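
For context, SigLIP swaps CLIP's softmax-based contrastive objective for a pairwise sigmoid loss over every image-text pair in a batch. A sketch of the objective following Zhai et al. (2023), where $\mathbf{x}_i$ and $\mathbf{y}_j$ are the normalized image and text embeddings, $t$ and $b$ are a learned temperature and bias, and $z_{ij}$ is $1$ for matched pairs and $-1$ otherwise (sign conventions for $b$ vary by write-up):

$$
\mathcal{L} = -\frac{1}{|\mathcal{B}|}\sum_{i=1}^{|\mathcal{B}|}\sum_{j=1}^{|\mathcal{B}|}\log\sigma\bigl(z_{ij}\,(t\,\mathbf{x}_i\cdot\mathbf{y}_j + b)\bigr)
$$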

Fixes #2498 

### What’s Implemented

#### 1. `SigLIP` Embedding Class

* Added `SigLipEmbeddings` under `python/lancedb/embeddings/siglip.py`
* Implements (see the usage sketch after this list):

  * `compute_source_embeddings`, batched and parallelized via `_parallel_get`
  * `compute_query_embeddings` for both text and image queries
  * Optional L2 normalization of text and image features
  * Batch-wise progress logging (tqdm) for image embedding
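
A minimal sketch of how the new class is exercised (model weights download on first use; the 768-dim figure holds for the default `google/siglip-base-patch16-224` checkpoint):

```python
from lancedb.embeddings import get_registry

# Defaults as declared in this PR: siglip-base-patch16-224, cpu, batch_size=64
func = get_registry().get("siglip").create()

# Text query -> one normalized vector
[text_vec] = func.compute_query_embeddings("a photo of a dog")
assert text_vec.shape == (func.ndims(),)  # 768 for the base checkpoint

# Image sources (URIs or raw bytes) -> one vector per input, embedded
# in batches of `batch_size` with a tqdm progress bar
uris = ["http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"]
vectors = func.compute_source_embeddings(uris)
```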

#### 2. Registry Integration

* Registered `SigLipEmbeddings` in `embeddings/__init__.py`
* `SigLIP` is now usable via `get_registry().get("siglip").create(...)` (see the sketch below)
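
A hedged sketch of how the registered function plugs into a table schema; it mirrors the new `test_siglip` test below, with a hypothetical local database path:

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

func = get_registry().get("siglip").create(max_retries=0)

class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()  # embedded automatically on add()
    vector: Vector(func.ndims()) = func.VectorField()

db = lancedb.connect("data/siglip-demo")  # hypothetical path
table = db.create_table("images", schema=Images)
table.add([{"label": "dog",
            "image_uri": "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"}])
hit = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
```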

#### 3. Evaluation Benchmark Support

* Added SigLIP to `test_embeddings_slow.py` for side-by-side
benchmarking with OpenCLIP and ImageBind


### New Test Methods

#### `test_siglip`

* End-to-end test verifying embedding table creation and vector shape for SigLIP

![test_siglip run output](https://github.com/user-attachments/assets/e5582ee1-80a3-43d7-a7a1-26ceecce9f4d)


#### `test_siglip_vs_openclip_vs_imagebind_benchmark_full`

* Benchmarks (a sketch of the recall metric follows the screenshot):

  * **Recall@1 / 5 / 10**
  * **mAP (Mean Average Precision)**
  * **Embedding & search latency**
  * Dimensionality reporting

![benchmark run output](https://github.com/user-attachments/assets/455bf30f-62b7-4684-a3f3-ad52e2a1ffe5)
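
For reference, a sketch (not the test's exact code; names here are illustrative) of the Recall@k metric reported above:

```python
def recall_at_k(ranked_labels: list[list[str]], true_labels: list[str], k: int) -> float:
    """Fraction of queries whose true label appears in the top-k results."""
    hits = sum(true in ranked[:k] for ranked, true in zip(ranked_labels, true_labels))
    return hits / len(true_labels)

# 2 of 3 queries rank their true label in the top 2 -> recall@2 = 0.67
ranked = [["dog", "cat"], ["horse", "dog"], ["cat", "horse"]]
print(recall_at_k(ranked, ["dog", "dog", "dog"], k=2))  # 0.666...
```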


### Notes

* SigLIP outputs 768D embeddings (vs. 512D for OpenCLIP); see the snippet after these notes
* The benchmark shows competitive retrieval quality despite the higher dimensionality
* I'm still new to contributing to open source and learning as I go. Please feel free to suggest improvements; I'm happy to make changes!
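
A quick way to confirm the dimensionality note, assuming the default checkpoints and the `open-clip` registry alias (models download on first call):

```python
from lancedb.embeddings import get_registry

siglip = get_registry().get("siglip").create()
openclip = get_registry().get("open-clip").create()
print(siglip.ndims(), openclip.ndims())  # 768 vs. 512
```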
---

Author: Poornachandra.A.N · 2025-08-05 00:12:39 +05:30 (committed by GitHub)
Parent: 02595dc475 · Commit: 7d0127b376
6 changed files with 257 additions and 22 deletions

python/.gitignore

@@ -1,2 +1,3 @@
 # Test data created by some example tests
 data/
+_lancedb.pyd

python/pyproject.toml

@@ -68,8 +68,9 @@ dev = [
     "pyright",
     'typing-extensions>=4.0.0; python_version < "3.11"',
 ]
-docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
+docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings-python"]
 clip = ["torch", "pillow", "open-clip-torch"]
+siglip = ["torch", "pillow", "transformers>=4.41.0","sentencepiece"]
 embeddings = [
     "requests>=2.31.0",
     "openai>=1.6.1",
@@ -87,6 +88,7 @@ embeddings = [
     "botocore>=1.31.57",
     'ibm-watsonx-ai>=1.1.2; python_version >= "3.10"',
     "ollama>=0.3.0",
+    "sentencepiece"
 ]
 azure = ["adlfs>=2024.2.0"]

python/lancedb/__init__.py

@@ -241,4 +241,4 @@ def __warn_on_fork():
 if hasattr(os, "register_at_fork"):
-    os.register_at_fork(before=__warn_on_fork)
+    os.register_at_fork(before=__warn_on_fork)  # type: ignore[attr-defined]

python/lancedb/embeddings/__init__.py

@@ -20,3 +20,4 @@ from .jinaai import JinaEmbeddings
 from .watsonx import WatsonxEmbeddings
 from .voyageai import VoyageAIEmbeddingFunction
 from .colpali import ColPaliEmbeddings
+from .siglip import SigLipEmbeddings

python/lancedb/embeddings/siglip.py

@@ -0,0 +1,148 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import concurrent.futures
+import io
+import os
+from typing import TYPE_CHECKING, List, Union
+import urllib.parse as urlparse
+
+import numpy as np
+import pyarrow as pa
+from tqdm import tqdm
+from pydantic import PrivateAttr
+
+from ..util import attempt_import_or_raise
+from .base import EmbeddingFunction
+from .registry import register
+from .utils import IMAGES, url_retrieve
+
+if TYPE_CHECKING:
+    import PIL
+    import torch
+
+
+@register("siglip")
+class SigLipEmbeddings(EmbeddingFunction):
+    model_name: str = "google/siglip-base-patch16-224"
+    device: str = "cpu"
+    batch_size: int = 64
+    normalize: bool = True
+
+    _model = PrivateAttr()
+    _processor = PrivateAttr()
+    _tokenizer = PrivateAttr()
+    _torch = PrivateAttr()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        transformers = attempt_import_or_raise("transformers")
+        self._torch = attempt_import_or_raise("torch")
+        self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
+        self._model = transformers.SiglipModel.from_pretrained(self.model_name)
+        self._model.to(self.device)
+        self._model.eval()
+        self._ndims = None
+
+    def ndims(self):
+        if self._ndims is None:
+            self._ndims = self.generate_text_embeddings("foo").shape[0]
+        return self._ndims
+
+    def compute_query_embeddings(
+        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
+    ) -> List[np.ndarray]:
+        if isinstance(query, str):
+            return [self.generate_text_embeddings(query)]
+        else:
+            PIL = attempt_import_or_raise("PIL", "pillow")
+            if isinstance(query, PIL.Image.Image):
+                return [self.generate_image_embedding(query)]
+            else:
+                raise TypeError("SigLIP supports str or PIL Image as query")
+
+    def generate_text_embeddings(self, text: str) -> np.ndarray:
+        torch = self._torch
+        text_inputs = self._processor(
+            text=text,
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=64,
+        ).to(self.device)
+        with torch.no_grad():
+            text_features = self._model.get_text_features(**text_inputs)
+            if self.normalize:
+                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+            return text_features.cpu().detach().numpy().squeeze()
+
+    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
+        if isinstance(images, (str, bytes)):
+            images = [images]
+        elif isinstance(images, pa.Array):
+            images = images.to_pylist()
+        elif isinstance(images, pa.ChunkedArray):
+            images = images.combine_chunks().to_pylist()
+        return images
+
+    def compute_source_embeddings(
+        self, images: IMAGES, *args, **kwargs
+    ) -> List[np.ndarray]:
+        images = self.sanitize_input(images)
+        embeddings = []
+        for i in range(0, len(images), self.batch_size):
+            j = min(i + self.batch_size, len(images))
+            batch = images[i:j]
+            embeddings.extend(self._parallel_get(batch))
+        return embeddings
+
+    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(self.generate_image_embedding, image)
+                for image in images
+            ]
+            return [f.result() for f in tqdm(futures, desc="SigLIP Embedding")]
+
+    def generate_image_embedding(
+        self, image: Union[str, bytes, "PIL.Image.Image"]
+    ) -> np.ndarray:
+        image = self._to_pil(image)
+        image = self._processor(images=image, return_tensors="pt")["pixel_values"]
+        return self._encode_and_normalize_image(image)
+
+    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor") -> np.ndarray:
+        torch = self._torch
+        with torch.no_grad():
+            image_features = self._model.get_image_features(
+                image_tensor.to(self.device)
+            )
+            if self.normalize:
+                image_features = image_features / image_features.norm(
+                    dim=-1, keepdim=True
+                )
+            return image_features.cpu().detach().numpy().squeeze()
+
+    def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
+        PIL = attempt_import_or_raise("PIL", "pillow")
+        if isinstance(image, PIL.Image.Image):
+            return image.convert("RGB") if image.mode != "RGB" else image
+        elif isinstance(image, bytes):
+            return PIL.Image.open(io.BytesIO(image)).convert("RGB")
+        elif isinstance(image, str):
+            parsed = urlparse.urlparse(image)
+            if parsed.scheme == "file":
+                return PIL.Image.open(parsed.path).convert("RGB")
+            elif parsed.scheme == "":
+                path = image if os.name == "nt" else parsed.path
+                return PIL.Image.open(path).convert("RGB")
+            elif parsed.scheme.startswith("http"):
+                image_bytes = url_retrieve(image)
+                return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
+            else:
+                raise NotImplementedError("Only local and http(s) urls are supported")
+        else:
+            raise ValueError(f"Unsupported image type: {type(image)}")

python/tests/test_embeddings_slow.py

@@ -4,7 +4,6 @@
 import importlib
 import io
 import os
-
 import lancedb
 import numpy as np
 import pandas as pd
@@ -12,7 +11,6 @@ import pyarrow as pa
 import pytest
 from lancedb.embeddings import get_registry
 from lancedb.pydantic import LanceModel, Vector, MultiVector
-import requests
 
 # These are integration tests for embedding functions.
 # They are slow because they require downloading models
@@ -98,9 +96,34 @@ def test_basic_text_embeddings(alias, tmp_path):
     assert not np.allclose(actual.vector, actual.vector2)
 
 
+@pytest.fixture(scope="module")
+def test_images():
+    import requests
+
+    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
+    uris = [
+        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
+        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
+        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
+    ]
+    image_bytes = [requests.get(uri).content for uri in uris]
+    return labels, uris, image_bytes
+
+
+@pytest.fixture(scope="module")
+def query_image_bytes():
+    import requests
+
+    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
+    image_bytes = requests.get(query_image_uri).content
+    return image_bytes
+
+
 @pytest.mark.slow
-def test_openclip(tmp_path):
+def test_openclip(tmp_path, test_images, query_image_bytes):
     from PIL import Image
 
     db = lancedb.connect(tmp_path)
@@ -114,20 +137,12 @@ def test_openclip(tmp_path):
         vector: Vector(func.ndims()) = func.VectorField()
         vec_from_bytes: Vector(func.ndims()) = func.VectorField()
 
+    labels, uris, image_bytes_list = test_images
     table = db.create_table("images", schema=Images)
-    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
-    uris = [
-        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
-        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
-        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
-        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
-        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
-        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
-    ]
-    # get each uri as bytes
-    image_bytes = [requests.get(uri).content for uri in uris]
     table.add(
-        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
+        pd.DataFrame(
+            {"label": labels, "image_uri": uris, "image_bytes": image_bytes_list}
+        )
     )
 
     # text search
@@ -146,9 +161,7 @@ def test_openclip(tmp_path):
     assert np.allclose(actual.vector, frombytes.vector)
 
     # image search
-    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
-    image_bytes = requests.get(query_image_uri).content
-    query_image = Image.open(io.BytesIO(image_bytes))
+    query_image = Image.open(io.BytesIO(query_image_bytes))
     actual = (
         table.search(query_image, vector_column_name="vector")
         .limit(1)
@@ -524,6 +537,8 @@ def test_voyageai_embedding_function():
     os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
 )
 def test_voyageai_multimodal_embedding_function():
+    import requests
+
     voyageai = (
         get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
     )
@@ -639,3 +654,71 @@ def test_colpali(tmp_path):
     assert len(first_row["image_vectors"][0]) == func.ndims(), (
         "Vector dimension mismatch"
     )
+
+
+@pytest.mark.slow
+def test_siglip(tmp_path, test_images, query_image_bytes):
+    from PIL import Image
+
+    labels, uris, image_bytes = test_images
+
+    db = lancedb.connect(tmp_path)
+    registry = get_registry()
+    func = registry.get("siglip").create(max_retries=0)
+
+    class Images(LanceModel):
+        label: str
+        image_uri: str = func.SourceField()
+        image_bytes: bytes = func.SourceField()
+        vector: Vector(func.ndims()) = func.VectorField()
+        vec_from_bytes: Vector(func.ndims()) = func.VectorField()
+
+    table = db.create_table("images", schema=Images)
+    table.add(
+        pd.DataFrame(
+            {
+                "label": labels,
+                "image_uri": uris,
+                "image_bytes": image_bytes,
+            }
+        )
+    )
+
+    # Text search
+    actual = (
+        table.search("man's best friend", vector_column_name="vector")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == "dog"
+
+    frombytes = (
+        table.search("man's best friend", vector_column_name="vec_from_bytes")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == frombytes.label
+    assert np.allclose(actual.vector, frombytes.vector)
+
+    # Image search
+    query_image = Image.open(io.BytesIO(query_image_bytes))
+    actual = (
+        table.search(query_image, vector_column_name="vector")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == "dog"
+
+    other = (
+        table.search(query_image, vector_column_name="vec_from_bytes")
+        .limit(1)
+        .to_pydantic(Images)[0]
+    )
+    assert actual.label == other.label
+
+    arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
+    assert np.allclose(
+        arrow_table["vector"].combine_chunks().values.to_numpy(),
+        arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
+    )