diff --git a/python/lancedb/embeddings/imagebind.py b/python/lancedb/embeddings/imagebind.py
new file mode 100644
index 00000000..eb89d505
--- /dev/null
+++ b/python/lancedb/embeddings/imagebind.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2023. LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import cached_property
+from typing import List, Union
+
+import numpy as np
+import pyarrow as pa
+
+from ..util import attempt_import_or_raise
+from .base import EmbeddingFunction
+from .registry import register
+from .utils import AUDIO, IMAGES, TEXT
+
+
+@register("imagebind")
+class ImageBindEmbeddings(EmbeddingFunction):
+    """
+    An embedding function backed by the ImageBind model.
+
+    ImageBind generates multi-modal embeddings across six modalities
+    (images, text, audio, depth, thermal, and IMU data). This class exposes
+    the text, image and audio encoders.
+
+    To install the package, run:
+    `pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
+    """
+
+    name: str = "imagebind_huge"
+    device: str = "cpu"
+    normalize: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._ndims = 1024
+        self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
+        self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")
+
+    @cached_property
+    def embedding_model(self):
+        """
+        Get the embedding model. This is cached so that the model is only loaded
+        once per instance.
+        """
+        return self.get_embedding_model()
+
+    @cached_property
+    def _data(self):
+        """
+        Get the data module from imagebind
+        """
+        return attempt_import_or_raise("imagebind.data", "imagebind")
+
+    @cached_property
+    def _ModalityType(self):
+        """
+        Get the ModalityType from imagebind
+        """
+        imagebind = attempt_import_or_raise("imagebind", "imagebind")
+        return imagebind.imagebind_model.ModalityType
+
+    def ndims(self):
+        """Return the dimensionality of the embeddings (1024)."""
+        return self._ndims
+
+    def compute_query_embeddings(
+        self, query: str, *args, **kwargs
+    ) -> List[np.ndarray]:
+        """
+        Compute the embeddings for a given user query
+
+        Parameters
+        ----------
+        query : str
+            The query to embed. A query can be either text, an image path or
+            an audio path.
+        """
+        query = self.sanitize_input(query)
+        return [self._embed(query)]
+
+    def _embed(self, inputs) -> np.ndarray:
+        """
+        Embed a non-empty batch, choosing the modality from its first item.
+
+        An item ending in a known audio or image extension (case-insensitive)
+        marks the batch as file paths of that kind; anything else is text.
+        """
+        first = inputs[0]
+        if not isinstance(first, str):
+            raise TypeError(
+                "imagebind inputs must be text or file paths given as str, "
+                f"got {type(first).__name__}"
+            )
+        suffix = first.lower()
+        if suffix.endswith(self._audio_extensions):
+            return self.generate_audio_embeddings(inputs)
+        if suffix.endswith(self._image_extensions):
+            return self.generate_image_embeddings(inputs)
+        return self.generate_text_embeddings(inputs)
+
+    def _encode(self, modality, batch) -> np.ndarray:
+        """
+        Run the model on one pre-processed batch and return squeezed features.
+        """
+        torch = attempt_import_or_raise("torch")
+        with torch.no_grad():
+            features = self.embedding_model({modality: batch})[modality]
+            if self.normalize:
+                features /= features.norm(dim=-1, keepdim=True)
+            return features.cpu().numpy().squeeze()
+
+    def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
+        """Embed image file paths; the output is squeezed (no size-1 axes)."""
+        batch = self._data.load_and_transform_vision_data(image, self.device)
+        return self._encode(self._ModalityType.VISION, batch)
+
+    def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
+        """Embed audio file paths; the output is squeezed (no size-1 axes)."""
+        batch = self._data.load_and_transform_audio_data(audio, self.device)
+        return self._encode(self._ModalityType.AUDIO, batch)
+
+    def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
+        """Embed texts; the output is squeezed (no size-1 axes)."""
+        batch = self._data.load_and_transform_text(text, self.device)
+        return self._encode(self._ModalityType.TEXT, batch)
+
+    def compute_source_embeddings(
+        self, source: Union[IMAGES, AUDIO], *args, **kwargs
+    ) -> List[np.ndarray]:
+        """
+        Get the embeddings for the given sourcefield column in the pydantic model.
+
+        Returns one vector per input item, and an empty list for empty input.
+        """
+        source = self.sanitize_input(source)
+        if len(source) == 0:
+            return []
+        # The generate_* helpers squeeze their output, so a one-item batch comes
+        # back 1-D; restore the batch axis so each row is one embedding.
+        return list(np.atleast_2d(self._embed(source)))
+
+    def sanitize_input(
+        self, input: Union[IMAGES, AUDIO]
+    ) -> Union[List[bytes], np.ndarray]:
+        """
+        Sanitize the input to the embedding function.
+
+        A single str/bytes is wrapped in a list and pyarrow arrays are
+        converted to Python lists; any other input is returned unchanged.
+        """
+        if isinstance(input, (str, bytes)):
+            input = [input]
+        elif isinstance(input, pa.Array):
+            input = input.to_pylist()
+        elif isinstance(input, pa.ChunkedArray):
+            input = input.combine_chunks().to_pylist()
+        return input
+
+    def get_embedding_model(self):
+        """
+        Load the pretrained imagebind_huge model in eval mode on `self.device`.
+        """
+        imagebind = attempt_import_or_raise("imagebind", "imagebind")
+        model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
+        model.eval()
+        model.to(self.device)
+        return model
diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py
index ed9162ba..fe997bbc 100644
--- a/python/lancedb/embeddings/utils.py
+++ b/python/lancedb/embeddings/utils.py
@@ -36,6 +36,7 @@ TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
 IMAGES = Union[
     str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
 ]
+AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
 
 
 @deprecated
diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py
index dff931c1..11b024f0 100644
--- a/python/tests/test_embeddings_slow.py
+++ b/python/tests/test_embeddings_slow.py
@@ -28,6 +28,20 @@ from lancedb.pydantic import LanceModel, Vector
 # or connection to external api
 
 
+def _module_available(name):
+    """Return True if `name` can be found by importlib, else None."""
+    try:
+        return True if importlib.util.find_spec(name) is not None else None
+    except Exception:
+        # find_spec imports parent packages (e.g. `mlx` for `mlx.core`), which
+        # can fail in arbitrary ways; treat any failure as "not installed".
+        return None
+
+
+_mlx = _module_available("mlx.core")
+_imagebind = _module_available("imagebind")
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
 def test_basic_text_embeddings(alias, tmp_path):
@@ -158,6 +172,81 @@ def test_openclip(tmp_path):
     )
 
 
+@pytest.mark.skipif(
+    _imagebind is None,
+    reason="skip if imagebind not installed.",
+)
+@pytest.mark.slow
+def test_imagebind(tmp_path):
+    import os
+    import tempfile
+
+    import pandas as pd
+    import requests
+
+    import lancedb.embeddings.imagebind
+    from lancedb.embeddings import get_registry
+    from lancedb.pydantic import LanceModel, Vector
+
+    def download_images(image_uris, target_dir):
+        """Download every URI into target_dir; fail loudly on any error."""
+        downloaded_image_paths = []
+        for uri in image_uris:
+            response = requests.get(uri, timeout=30)
+            # A silently skipped download would misalign paths and labels.
+            response.raise_for_status()
+            image_path = os.path.join(target_dir, os.path.basename(uri))
+            with open(image_path, "wb") as out_file:
+                out_file.write(response.content)
+            downloaded_image_paths.append(image_path)
+        return downloaded_image_paths
+
+    db = lancedb.connect(tmp_path)
+    registry = get_registry()
+    func = registry.get("imagebind").create(max_retries=0)
+
+    class Images(LanceModel):
+        label: str
+        image_uri: str = func.SourceField()
+        vector: Vector(func.ndims()) = func.VectorField()
+
+    table = db.create_table("images", schema=Images)
+    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
+    uris = [
+        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
+        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
+        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
+    ]
+
+    # TemporaryDirectory removes the downloads on exit, even if an assert fails.
+    with tempfile.TemporaryDirectory() as temp_dir:
+        downloaded_images = download_images(uris, temp_dir)
+        table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
+
+        # text search
+        actual = (
+            table.search("man's best friend", vector_column_name="vector")
+            .limit(1)
+            .to_pydantic(Images)[0]
+        )
+        assert actual.label == "dog"
+
+        # image search
+        query_image_uris = [
+            "https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
+        ]
+        query_image_path = download_images(query_image_uris, temp_dir)[0]
+        actual = (
+            table.search(query_image_path, vector_column_name="vector")
+            .limit(1)
+            .to_pydantic(Images)[0]
+        )
+        assert actual.label == "dog"
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(
     os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@@ -217,13 +306,6 @@ def test_gemini_embedding(tmp_path):
     assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
 
 
-try:
-    if importlib.util.find_spec("mlx.core") is not None:
-        _mlx = True
-except ImportError:
-    _mlx = None
-
-
 @pytest.mark.skipif(
     _mlx is None,
     reason="mlx tests only required for apple users.",