From c1dfad675adf15b19b3a59917af4b08d55cfb2f1 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Tue, 5 Mar 2024 10:08:19 -0800 Subject: [PATCH] see if we make EncodedImage work --- python/python/lancedb/embeddings/open_clip.py | 5 +++++ python/python/lancedb/table.py | 21 +++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/python/python/lancedb/embeddings/open_clip.py b/python/python/lancedb/embeddings/open_clip.py index 4d0a3a32..3df1d1fb 100644 --- a/python/python/lancedb/embeddings/open_clip.py +++ b/python/python/lancedb/embeddings/open_clip.py @@ -126,6 +126,10 @@ class OpenClipEmbeddings(EmbeddingFunction): """ Issue concurrent requests to retrieve the image data """ + return [ + self.generate_image_embedding(image) for image in tqdm(images) + ] + with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ executor.submit(self.generate_image_embedding, image) @@ -145,6 +149,7 @@ class OpenClipEmbeddings(EmbeddingFunction): The image to embed. If the image is a str, it is treated as a uri. If the image is bytes, it is treated as the raw image bytes. """ + import pdb; pdb.set_trace() torch = attempt_import_or_raise("torch") # TODO handle retry and errors for https image = self._to_pil(image) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index b07cef2f..c4297480 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -37,6 +37,7 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.fs as pa_fs from lance import LanceDataset +from lance.dependencies import _check_for_hugging_face from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME @@ -74,7 +75,16 @@ def _sanitize_data( on_bad_vectors: str, fill_value: Any, ): - if isinstance(data, list): + if _check_for_hugging_face(data): + # Huggingface datasets + import datasets + + if isinstance(data, datasets.Dataset): + if schema is None: + schema = data.features.arrow_schema + data = data.data.to_batches() + import pdb; pdb.set_trace() + elif isinstance(data, list): # convert to list of dict if data is a bunch of LanceModels if isinstance(data[0], LanceModel): schema = data[0].__class__.to_arrow_schema() @@ -136,11 +146,10 @@ def _to_record_batch_generator( data: Iterable, schema, metadata, on_bad_vectors, fill_value ): for batch in data: - if not isinstance(batch, pa.RecordBatch): - table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value) - for batch in table.to_batches(): - yield batch - else: + if isinstance(batch, pa.RecordBatch): + batch = pa.Table.from_batches([batch]) + table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value) + for batch in table.to_batches(): yield batch