See if we can make EncodedImage work

This commit is contained in:
Chang She
2024-03-05 10:08:19 -08:00
parent 2e1838a62a
commit c1dfad675a
2 changed files with 20 additions and 6 deletions

View File

@@ -126,6 +126,10 @@ class OpenClipEmbeddings(EmbeddingFunction):
"""
Issue concurrent requests to retrieve the image data
"""
return [
self.generate_image_embedding(image) for image in tqdm(images)
]
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.generate_image_embedding, image)
@@ -145,6 +149,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
import pdb; pdb.set_trace()
torch = attempt_import_or_raise("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)

View File

@@ -37,6 +37,7 @@ import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -74,7 +75,16 @@ def _sanitize_data(
on_bad_vectors: str,
fill_value: Any,
):
if isinstance(data, list):
if _check_for_hugging_face(data):
# Huggingface datasets
import datasets
if isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
data = data.data.to_batches()
import pdb; pdb.set_trace()
elif isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
@@ -136,11 +146,10 @@ def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value
):
for batch in data:
if not isinstance(batch, pa.RecordBatch):
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for batch in table.to_batches():
yield batch
else:
if isinstance(batch, pa.RecordBatch):
batch = pa.Table.from_batches([batch])
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for batch in table.to_batches():
yield batch