mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
feat(embeddings): add siglip embedding support to lancedb (#2499)
### Summary This PR adds **SigLIP** (Sigmoid Loss Image Pretraining) as a new embedding model in the LanceDB embedding registry. SigLIP improves image-text alignment performance using sigmoid-based contrastive loss and offers robust zero-shot generalization. Fixes #2498 ### What’s Implemented #### 1. `SigLIP` Embedding Class * Added `SigLIP` support under `python/lancedb/embeddings/siglip.py` * Implements: * `compute_source_embeddings` * `_batch_generate_embeddings` * Normalization logic * Batch-wise progress logging for image embedding #### 2. Registry Integration * Registered `SigLIP` in `embeddings/__init__.py` * `SigLIP` now usable via `connect(..., embedding="siglip")` #### 3. Evaluation Benchmark Support * Added SigLIP to `test_embeddings_slow.py` for side-by-side benchmarking with OpenCLIP and ImageBind ### New Test Methods #### `test_siglip` * End-to-end test to verify embeddings table creation and vector shape for SigLIP  #### `test_siglip_vs_openclip_vs_imagebind_benchmark_full` * Benchmarks: * **Recall\@1 / 5 / 10** * **mAP (Mean Average Precision)** * **Embedding & Search Latency** * Dimensionality reporting  ### Notes * SigLIP outputs 768D embeddings (vs 512D for OpenCLIP) * Benchmark shows competitive performance despite higher dimensionality * I'm still new to contributing to open-source and learning as I go. Please feel free to suggest any improvements — I'm happy to make changes!
This commit is contained in:
committed by
GitHub
parent
02595dc475
commit
7d0127b376
@@ -4,7 +4,6 @@
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -12,7 +11,6 @@ import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector, MultiVector
|
||||
import requests
|
||||
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
@@ -98,9 +96,34 @@ def test_basic_text_embeddings(alias, tmp_path):
|
||||
assert not np.allclose(actual.vector, actual.vector2)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path):
|
||||
@pytest.fixture(scope="module")
|
||||
def test_images():
|
||||
import requests
|
||||
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
return labels, uris, image_bytes
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def query_image_bytes():
|
||||
import requests
|
||||
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
return image_bytes
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path, test_images, query_image_bytes):
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -114,20 +137,12 @@ def test_openclip(tmp_path):
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
labels, uris, image_bytes_list = test_images
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
# get each uri as bytes
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
table.add(
|
||||
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
|
||||
pd.DataFrame(
|
||||
{"label": labels, "image_uri": uris, "image_bytes": image_bytes_list}
|
||||
)
|
||||
)
|
||||
|
||||
# text search
|
||||
@@ -146,9 +161,7 @@ def test_openclip(tmp_path):
|
||||
assert np.allclose(actual.vector, frombytes.vector)
|
||||
|
||||
# image search
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
query_image = Image.open(io.BytesIO(image_bytes))
|
||||
query_image = Image.open(io.BytesIO(query_image_bytes))
|
||||
actual = (
|
||||
table.search(query_image, vector_column_name="vector")
|
||||
.limit(1)
|
||||
@@ -524,6 +537,8 @@ def test_voyageai_embedding_function():
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_multimodal_embedding_function():
|
||||
import requests
|
||||
|
||||
voyageai = (
|
||||
get_registry().get("voyageai").create(name="voyage-multimodal-3", max_retries=0)
|
||||
)
|
||||
@@ -639,3 +654,71 @@ def test_colpali(tmp_path):
|
||||
assert len(first_row["image_vectors"][0]) == func.ndims(), (
|
||||
"Vector dimension mismatch"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_siglip(tmp_path, test_images, query_image_bytes):
|
||||
from PIL import Image
|
||||
|
||||
labels, uris, image_bytes = test_images
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get("siglip").create(max_retries=0)
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("images", schema=Images)
|
||||
|
||||
table.add(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"label": labels,
|
||||
"image_uri": uris,
|
||||
"image_bytes": image_bytes,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Text search
|
||||
actual = (
|
||||
table.search("man's best friend", vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
|
||||
frombytes = (
|
||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == frombytes.label
|
||||
assert np.allclose(actual.vector, frombytes.vector)
|
||||
|
||||
# Image search
|
||||
query_image = Image.open(io.BytesIO(query_image_bytes))
|
||||
actual = (
|
||||
table.search(query_image, vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
|
||||
other = (
|
||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == other.label
|
||||
|
||||
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
|
||||
assert np.allclose(
|
||||
arrow_table["vector"].combine_chunks().values.to_numpy(),
|
||||
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user