mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 12:52:58 +00:00
python(feat): Imagebind embedding fn support (#1003)
Added ImageBind embedding function support; the installation steps are given in the class docstring. The slow pytest checks were run locally. --------- Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
committed by
Weston Pace
parent
538d0320f7
commit
fdabf31984
172
python/lancedb/embeddings/imagebind.py
Normal file
172
python/lancedb/embeddings/imagebind.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import AUDIO, IMAGES, TEXT
|
||||
|
||||
|
||||
@register("imagebind")
class ImageBindEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the ImageBind API
    for generating multi-modal embeddings across
    six different modalities: images, text, audio, depth, thermal, and IMU data.

    This class exposes three of those modalities: text, images and audio.
    A batch is routed by the file extension of its *first* element (compared
    case-insensitively): an audio extension selects the audio encoder, an
    image extension selects the vision encoder, and anything else is embedded
    as text. A batch is therefore expected to hold a single modality.

    to download package, run :
        `pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
    """

    # NOTE(review): this field is not read by get_embedding_model(), which
    # always loads ``imagebind_huge``.
    name: str = "imagebind_huge"
    # Device string handed to the imagebind data loaders and to ``model.to``.
    device: str = "cpu"
    # When True, each embedding is divided by its norm along the last axis.
    normalize: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Dimensionality reported by ndims().
        self._ndims = 1024
        # Tuples, so they can be passed straight to ``str.endswith``.
        self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
        self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")

    @cached_property
    def embedding_model(self):
        """
        Get the embedding model. The result is cached on the instance, so
        get_embedding_model() runs at most once per instance.
        """
        return self.get_embedding_model()

    @cached_property
    def _data(self):
        """
        Get the data module from imagebind
        """
        data = attempt_import_or_raise("imagebind.data", "imagebind")
        return data

    @cached_property
    def _ModalityType(self):
        """
        Get the ModalityType from imagebind
        """
        imagebind = attempt_import_or_raise("imagebind", "imagebind")
        return imagebind.imagebind_model.ModalityType

    def ndims(self):
        """Return the dimensionality of the vectors this function produces."""
        return self._ndims

    def _generator_for(self, batch):
        """
        Pick the ``generate_*_embeddings`` method for a sanitized batch.

        Only the first element is inspected. The comparison is done on the
        lower-cased value so that e.g. ``photo.JPG`` is treated as an image
        path rather than silently embedded as text.
        """
        first = batch[0].lower()
        if first.endswith(self._audio_extensions):
            return self.generate_audio_embeddings
        if first.endswith(self._image_extensions):
            return self.generate_image_embeddings
        return self.generate_text_embeddings

    def _embed(self, modality_name: str, loader_name: str, batch) -> np.ndarray:
        """
        Run one modality through the model and return the squeezed features.

        Parameters
        ----------
        modality_name : str
            Attribute name on imagebind's ``ModalityType`` (e.g. ``"VISION"``).
        loader_name : str
            Name of the ``imagebind.data`` function that turns ``batch`` into
            model inputs; it is called as ``loader(batch, self.device)``.
        batch
            The sanitized inputs for that modality.

        Returns
        -------
        np.ndarray
            The model output for the modality, moved to CPU, converted to
            NumPy and passed through ``squeeze()`` (so a one-item batch comes
            back without its batch axis).
        """
        torch = attempt_import_or_raise("torch")
        modality = getattr(self._ModalityType, modality_name)
        loader = getattr(self._data, loader_name)
        inputs = {modality: loader(batch, self.device)}
        with torch.no_grad():
            features = self.embedding_model(inputs)[modality]
            if self.normalize:
                features /= features.norm(dim=-1, keepdim=True)
            return features.cpu().numpy().squeeze()

    def compute_query_embeddings(
        self, query: str, *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : str
            The query to embed. A query can be either text, image paths or audio paths.

        Returns
        -------
        List[np.ndarray]
            A one-element list holding the embedding array for the query.
        """
        query = self.sanitize_input(query)
        return [self._generator_for(query)(query)]

    def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
        """Embed image inputs with the vision encoder."""
        return self._embed("VISION", "load_and_transform_vision_data", image)

    def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
        """Embed audio inputs with the audio encoder."""
        return self._embed("AUDIO", "load_and_transform_audio_data", audio)

    def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
        """Embed text inputs with the text encoder."""
        return self._embed("TEXT", "load_and_transform_text", text)

    def compute_source_embeddings(
        self, source: Union[IMAGES, AUDIO], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Get the embeddings for the given sourcefield column in the pydantic model.

        Returns
        -------
        List[np.ndarray]
            One embedding vector per input item, in input order. An empty
            input yields an empty list.
        """
        source = self.sanitize_input(source)
        if len(source) == 0:
            # Nothing to embed; avoids an IndexError when peeking at source[0].
            return []
        features = self._generator_for(source)(source)
        # The generate_* methods squeeze their output, so a one-item batch
        # arrives as a 1-D vector. Restore the batch axis before splitting
        # into rows; otherwise the caller would get a list of scalars instead
        # of a list containing one vector.
        return list(np.atleast_2d(features))

    def sanitize_input(
        self, input: Union[IMAGES, AUDIO]
    ) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.

        A single ``str``/``bytes`` is wrapped in a list, pyarrow arrays and
        chunked arrays are converted to Python lists, and any other value is
        returned unchanged.
        """
        if isinstance(input, (str, bytes)):
            input = [input]
        elif isinstance(input, pa.Array):
            input = input.to_pylist()
        elif isinstance(input, pa.ChunkedArray):
            input = input.combine_chunks().to_pylist()
        return input

    def get_embedding_model(self):
        """
        fetches the imagebind embedding model

        Loads the pretrained ``imagebind_huge`` model, switches it to eval
        mode and moves it to ``self.device``.
        """
        imagebind = attempt_import_or_raise("imagebind", "imagebind")
        model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
        model.eval()
        model.to(self.device)
        return model
|
||||
@@ -36,6 +36,7 @@ TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
# Accepted input containers for image sources: a single value, a list of
# values, a pyarrow array / chunked array, or a NumPy array.
IMAGES = Union[
    str,
    bytes,
    List[str],
    List[bytes],
    pa.Array,
    pa.ChunkedArray,
    np.ndarray,
]
# Audio sources accept the same set of containers as image sources.
AUDIO = Union[
    str,
    bytes,
    List[str],
    List[bytes],
    pa.Array,
    pa.ChunkedArray,
    np.ndarray,
]
|
||||
|
||||
|
||||
@deprecated
|
||||
|
||||
@@ -28,6 +28,23 @@ from lancedb.pydantic import LanceModel, Vector
|
||||
# or connection to external api
|
||||
|
||||
|
||||
# Optional-dependency probes used by the skipif markers below.
# Each flag is True when the module can be located and None otherwise;
# any error raised while probing is treated as "not available".
try:
    _mlx = True if importlib.util.find_spec("mlx.core") is not None else None
except Exception:
    _mlx = None

try:
    _imagebind = True if importlib.util.find_spec("imagebind") is not None else None
except Exception:
    _imagebind = None
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||
def test_basic_text_embeddings(alias, tmp_path):
|
||||
@@ -158,6 +175,89 @@ def test_openclip(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    _imagebind is None,
    reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
    """
    End-to-end check of the "imagebind" embedding function.

    Downloads six labelled images, stores them in a table whose vector column
    is filled by the embedding function, then checks that both a text query
    and an image query return a row labelled "dog".
    """
    import os
    import shutil
    import tempfile

    import pandas as pd
    import requests

    # Imported for its side effect as well: loading the module runs the
    # @register("imagebind") decorator. It also binds ``lancedb`` locally.
    import lancedb.embeddings.imagebind
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector

    # The context manager removes the directory on exit, so no manual
    # cleanup is needed afterwards.
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Created temporary directory {temp_dir}")

        def download_images(image_uris):
            """Download each URI into temp_dir; return the saved file paths."""
            downloaded_image_paths = []
            for uri in image_uris:
                try:
                    # A timeout keeps a stalled server from hanging the test
                    # run, and the ``with`` releases the streamed connection.
                    with requests.get(uri, stream=True, timeout=30) as response:
                        if response.status_code == 200:
                            # Extract image name from URI
                            image_name = os.path.basename(uri)
                            image_path = os.path.join(temp_dir, image_name)
                            with open(image_path, "wb") as out_file:
                                shutil.copyfileobj(response.raw, out_file)
                            downloaded_image_paths.append(image_path)
                        else:
                            print(
                                f"Failed to download {uri}. "
                                f"Status: {response.status_code}"
                            )
                except Exception as e:  # noqa: PERF203
                    print(f"Failed to download {uri}. Error: {e}")
            return downloaded_image_paths

        db = lancedb.connect(tmp_path)
        registry = get_registry()
        func = registry.get("imagebind").create(max_retries=0)

        class Images(LanceModel):
            label: str
            image_uri: str = func.SourceField()
            vector: Vector(func.ndims()) = func.VectorField()

        table = db.create_table("images", schema=Images)
        labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
        uris = [
            "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
            "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
            "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
            "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
            "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
            "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
        ]
        downloaded_images = download_images(uris)
        # Labels are paired with paths by position, so a skipped download
        # would misalign them; fail here with a clear message instead.
        assert len(downloaded_images) == len(uris), (
            f"only {len(downloaded_images)} of {len(uris)} images downloaded"
        )
        table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))

        # text search
        actual = (
            table.search("man's best friend", vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"

        # image search
        query_image_uri = [
            "https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
        ]
        downloaded_images = download_images(query_image_uri)
        assert downloaded_images, "query image could not be downloaded"
        query_image_uri = downloaded_images[0]
        actual = (
            table.search(query_image_uri, vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
@@ -217,13 +317,6 @@ def test_gemini_embedding(tmp_path):
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
|
||||
# Probe for the optional ``mlx`` dependency used by the skipif marker below.
# Start from "not available" so ``_mlx`` is always bound: find_spec returns
# None (without raising) when the parent package exists but the submodule
# does not, which previously left the name undefined.
_mlx = None
try:
    if importlib.util.find_spec("mlx.core") is not None:
        _mlx = True
except (ImportError, ValueError):
    # ImportError (incl. ModuleNotFoundError): the parent package is missing.
    # ValueError: find_spec found an already-imported module without a spec.
    _mlx = None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_mlx is None,
|
||||
reason="mlx tests only required for apple users.",
|
||||
|
||||
Reference in New Issue
Block a user