diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index b33f8270..6bebfc62 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -238,6 +238,7 @@ def EncodedImage() -> Type[ImageMixin]: @classmethod def validate(cls, v): from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray + if isinstance(v, ImageURIArray): v = v.read_uris() if isinstance(v, pa.BinaryArray): diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 204f91a4..58dd7c83 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -14,6 +14,7 @@ import json import sys +import os from datetime import date, datetime from typing import List, Optional, Tuple @@ -255,11 +256,18 @@ def test_lance_model(): def test_lance_model_with_lance_types(): - img = (b'\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00' - b'\\x01\\x00\\x00\\x00\\x01\\x08\\x06\\x00\\x00\\x00\\x1f') - default_image_uris = ImageURIArray.from_uris(["/tmp/bar"]) + img = ( + b"\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00" + b"\\x01\\x00\\x00\\x00\\x01\\x08\\x06\\x00\\x00\\x00\\x1f" + ) + png_uris = [ + "file://" + os.path.join(os.path.dirname(__file__), "images/1.png"), + ] + default_image_uris = ImageURIArray.from_uris(png_uris) encoded_images = pa.array([img], pa.binary()) - default_encoded_images = pa.ExtensionArray.from_storage(EncodedImageType(), encoded_images) + default_encoded_images = pa.ExtensionArray.from_storage( + EncodedImageType(), encoded_images + ) class TestModel(LanceModel): encoded_images: EncodedImage() = Field(default=default_encoded_images) @@ -270,9 +278,12 @@ def test_lance_model_with_lance_types(): assert TestModel.field_names() == ["encoded_images", "image_uris"] expected_model = TestModel() - actual_model = TestModel(encoded_images=default_encoded_images, image_uris=default_image_uris) + actual_model = TestModel( + encoded_images=default_encoded_images, image_uris=default_image_uris + ) assert expected_model == actual_model - # TODO: add images to repo and test with real path - # actual_model = TestModel(encoded_images=default_image_uris, image_uris=default_image_uris) - # assert expected_model == actual_model + actual_model = TestModel( + encoded_images=default_image_uris, image_uris=default_image_uris + ) + assert expected_model == actual_model diff --git a/python/tests/images/1.png b/python/tests/images/1.png new file mode 100644 index 00000000..7097d1a7 Binary files /dev/null and b/python/tests/images/1.png differ