mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
feat(embeddings): add siglip embedding support to lancedb (#2499)
### Summary This PR adds **SigLIP** (Sigmoid Loss Image Pretraining) as a new embedding model in the LanceDB embedding registry. SigLIP improves image-text alignment performance using sigmoid-based contrastive loss and offers robust zero-shot generalization. Fixes #2498 ### What’s Implemented #### 1. `SigLIP` Embedding Class * Added `SigLIP` support under `python/lancedb/embeddings/siglip.py` * Implements: * `compute_source_embeddings` * `_batch_generate_embeddings` * Normalization logic * Batch-wise progress logging for image embedding #### 2. Registry Integration * Registered `SigLIP` in `embeddings/__init__.py` * `SigLIP` now usable via `connect(..., embedding="siglip")` #### 3. Evaluation Benchmark Support * Added SigLIP to `test_embeddings_slow.py` for side-by-side benchmarking with OpenCLIP and ImageBind ### New Test Methods #### `test_siglip` * End-to-end test to verify embeddings table creation and vector shape for SigLIP  #### `test_siglip_vs_openclip_vs_imagebind_benchmark_full` * Benchmarks: * **Recall\@1 / 5 / 10** * **mAP (Mean Average Precision)** * **Embedding & Search Latency** * Dimensionality reporting  ### Notes * SigLIP outputs 768D embeddings (vs 512D for OpenCLIP) * Benchmark shows competitive performance despite higher dimensionality * I'm still new to contributing to open-source and learning as I go. Please feel free to suggest any improvements — I'm happy to make changes!
This commit is contained in:
committed by
GitHub
parent
02595dc475
commit
7d0127b376
1
python/.gitignore
vendored
1
python/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
# Test data created by some example tests
|
||||
data/
|
||||
_lancedb.pyd
|
||||
|
||||
@@ -68,8 +68,9 @@ dev = [
|
||||
"pyright",
|
||||
'typing-extensions>=4.0.0; python_version < "3.11"',
|
||||
]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings-python"]
|
||||
clip = ["torch", "pillow", "open-clip-torch"]
|
||||
siglip = ["torch", "pillow", "transformers>=4.41.0","sentencepiece"]
|
||||
embeddings = [
|
||||
"requests>=2.31.0",
|
||||
"openai>=1.6.1",
|
||||
@@ -87,6 +88,7 @@ embeddings = [
|
||||
"botocore>=1.31.57",
|
||||
'ibm-watsonx-ai>=1.1.2; python_version >= "3.10"',
|
||||
"ollama>=0.3.0",
|
||||
"sentencepiece"
|
||||
]
|
||||
azure = ["adlfs>=2024.2.0"]
|
||||
|
||||
|
||||
@@ -241,4 +241,4 @@ def __warn_on_fork():
|
||||
|
||||
|
||||
if hasattr(os, "register_at_fork"):
|
||||
os.register_at_fork(before=__warn_on_fork)
|
||||
os.register_at_fork(before=__warn_on_fork) # type: ignore[attr-defined]
|
||||
|
||||
@@ -20,3 +20,4 @@ from .jinaai import JinaEmbeddings
|
||||
from .watsonx import WatsonxEmbeddings
|
||||
from .voyageai import VoyageAIEmbeddingFunction
|
||||
from .colpali import ColPaliEmbeddings
|
||||
from .siglip import SigLipEmbeddings
|
||||
|
||||
148
python/python/lancedb/embeddings/siglip.py
Normal file
148
python/python/lancedb/embeddings/siglip.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import concurrent.futures
|
||||
import io
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
import urllib.parse as urlparse
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from tqdm import tqdm
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import IMAGES, url_retrieve
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
|
||||
@register("siglip")
|
||||
class SigLipEmbeddings(EmbeddingFunction):
|
||||
model_name: str = "google/siglip-base-patch16-224"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
|
||||
_model = PrivateAttr()
|
||||
_processor = PrivateAttr()
|
||||
_tokenizer = PrivateAttr()
|
||||
_torch = PrivateAttr()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
self._torch = attempt_import_or_raise("torch")
|
||||
|
||||
self._processor = transformers.AutoProcessor.from_pretrained(self.model_name)
|
||||
self._model = transformers.SiglipModel.from_pretrained(self.model_name)
|
||||
self._model.to(self.device)
|
||||
self._model.eval()
|
||||
self._ndims = None
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("SigLIP supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self._torch
|
||||
text_inputs = self._processor(
|
||||
text=text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=64,
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
text_features = self._model.get_text_features(**text_inputs)
|
||||
if self.normalize:
|
||||
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().detach().numpy().squeeze()
|
||||
|
||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
if isinstance(images, (str, bytes)):
|
||||
images = [images]
|
||||
elif isinstance(images, pa.Array):
|
||||
images = images.to_pylist()
|
||||
elif isinstance(images, pa.ChunkedArray):
|
||||
images = images.combine_chunks().to_pylist()
|
||||
return images
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
images = self.sanitize_input(images)
|
||||
embeddings = []
|
||||
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
j = min(i + self.batch_size, len(images))
|
||||
batch = images[i:j]
|
||||
embeddings.extend(self._parallel_get(batch))
|
||||
return embeddings
|
||||
|
||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [f.result() for f in tqdm(futures, desc="SigLIP Embedding")]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
image = self._to_pil(image)
|
||||
image = self._processor(images=image, return_tensors="pt")["pixel_values"]
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor") -> np.ndarray:
|
||||
torch = self._torch
|
||||
with torch.no_grad():
|
||||
image_features = self._model.get_image_features(
|
||||
image_tensor.to(self.device)
|
||||
)
|
||||
if self.normalize:
|
||||
image_features = image_features / image_features.norm(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
return image_features.cpu().detach().numpy().squeeze()
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return image.convert("RGB") if image.mode != "RGB" else image
|
||||
elif isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image)).convert("RGB")
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path).convert("RGB")
|
||||
elif parsed.scheme == "":
|
||||
path = image if os.name == "nt" else parsed.path
|
||||
return PIL.Image.open(path).convert("RGB")
|
||||
elif parsed.scheme.startswith("http"):
|
||||
image_bytes = url_retrieve(image)
|
||||
return PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
else:
|
||||
raise ValueError(f"Unsupported image type: {type(image)}")
|
||||
@@ -4,7 +4,6 @@
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -12,7 +11,6 @@ import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector, MultiVector
|
||||
import requests
|
||||
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
@@ -98,9 +96,34 @@ def test_basic_text_embeddings(alias, tmp_path):
|
||||
assert not np.allclose(actual.vector, actual.vector2)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path):
|
||||
@pytest.fixture(scope="module")
|
||||
def test_images():
|
||||
import requests
|
||||
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
return labels, uris, image_bytes
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def query_image_bytes():
|
||||
import requests
|
||||
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
return image_bytes
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path, test_images, query_image_bytes):
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -114,20 +137,12 @@ def test_openclip(tmp_path):
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
labels, uris, image_bytes_list = test_images
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
# get each uri as bytes
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
table.add(
|
||||
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
|
||||
pd.DataFrame(
|
||||
{"label": labels, "image_uri": uris, "image_bytes": image_bytes_list}
|
||||
)
|
||||
)
|
||||
|
||||
# text search
|
||||
@@ -146,9 +161,7 @@ def test_openclip(tmp_path):
|
||||
assert np.allclose(actual.vector, frombytes.vector)
|
||||
|
||||
# image search
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
query_image = Image.open(io.BytesIO(image_bytes))
|
||||
query_image = Image.open(io.BytesIO(query_image_bytes))
|
||||
actual = (
|
||||
table.search(query_image, vector_column_name="vector")
|
||||
.limit(1)
|
||||
@@ -524,6 +537,8 @@ def test_voyageai_embedding_function():
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_embedding_function():
|
||||
import requests
|
||||
|
||||
voyageai = (
|
||||
get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
|
||||
)
|
||||
@@ -639,3 +654,71 @@ def test_colpali(tmp_path):
|
||||
assert len(first_row["image_vectors"][0]) == func.ndims(), (
|
||||
"Vector dimension mismatch"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_siglip(tmp_path, test_images, query_image_bytes):
|
||||
from PIL import Image
|
||||
|
||||
labels, uris, image_bytes = test_images
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get("siglip").create(max_retries=0)
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("images", schema=Images)
|
||||
|
||||
table.add(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"label": labels,
|
||||
"image_uri": uris,
|
||||
"image_bytes": image_bytes,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Text search
|
||||
actual = (
|
||||
table.search("man's best friend", vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
|
||||
frombytes = (
|
||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == frombytes.label
|
||||
assert np.allclose(actual.vector, frombytes.vector)
|
||||
|
||||
# Image search
|
||||
query_image = Image.open(io.BytesIO(query_image_bytes))
|
||||
actual = (
|
||||
table.search(query_image, vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
|
||||
other = (
|
||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == other.label
|
||||
|
||||
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
|
||||
assert np.allclose(
|
||||
arrow_table["vector"].combine_chunks().values.to_numpy(),
|
||||
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user