python(feat): Imagebind embedding fn support (#1003)

Added imagebind fn support , steps to install mentioned in docstring. 
pytest slow checks done locally

---------

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Raghav Dixit
2024-02-22 01:17:08 -05:00
committed by Weston Pace
parent 538d0320f7
commit fdabf31984
3 changed files with 273 additions and 7 deletions

View File

@@ -28,6 +28,23 @@ from lancedb.pydantic import LanceModel, Vector
# or connection to external api
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
else:
_mlx = None
except Exception:
_mlx = None
try:
if importlib.util.find_spec("imagebind") is not None:
_imagebind = True
else:
_imagebind = None
except Exception:
_imagebind = None
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_basic_text_embeddings(alias, tmp_path):
@@ -158,6 +175,89 @@ def test_openclip(tmp_path):
)
@pytest.mark.skipif(
_imagebind is None,
reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
import os
import shutil
import tempfile
import pandas as pd
import requests
import lancedb.embeddings.imagebind
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
with tempfile.TemporaryDirectory() as temp_dir:
print(f"Created temporary directory {temp_dir}")
def download_images(image_uris):
downloaded_image_paths = []
for uri in image_uris:
try:
response = requests.get(uri, stream=True)
if response.status_code == 200:
# Extract image name from URI
image_name = os.path.basename(uri)
image_path = os.path.join(temp_dir, image_name)
with open(image_path, "wb") as out_file:
shutil.copyfileobj(response.raw, out_file)
downloaded_image_paths.append(image_path)
except Exception as e: # noqa: PERF203
print(f"Failed to download {uri}. Error: {e}")
return temp_dir, downloaded_image_paths
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("imagebind").create(max_retries=0)
class Images(LanceModel):
label: str
image_uri: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
temp_dir, downloaded_images = download_images(uris)
table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
# text search
actual = (
table.search("man's best friend", vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
# image search
query_image_uri = [
"https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
]
temp_dir, downloaded_images = download_images(query_image_uri)
query_image_uri = downloaded_images[0]
actual = (
table.search(query_image_uri, vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)
print(f"Deleted temporary directory {temp_dir}")
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@@ -217,13 +317,6 @@ def test_gemini_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",