diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 6bebfc62..355cfcb1 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -227,7 +227,7 @@ def EncodedImage() -> Type[ImageMixin]: ) -> CoreSchema: return core_schema.no_info_after_validator_function( cls, - core_schema.binary_schema(), + core_schema.str_schema(), ) @classmethod @@ -292,7 +292,7 @@ def ImageURI() -> Type[ImageMixin]: ) -> CoreSchema: return core_schema.no_info_after_validator_function( cls, - core_schema.string_schema(), + core_schema.str_schema(), ) @classmethod @@ -302,9 +302,13 @@ def ImageURI() -> Type[ImageMixin]: # For pydantic v1 @classmethod def validate(cls, v): - # if not isinstance(v, str): - # raise TypeError("A str is needed") - # return cls(v) + from lance.arrow import ImageURIArray, ImageURIType + + if isinstance(v, (str, pa.StringArray)): + v = pa.ExtensionArray.from_storage(ImageURIType(), v) + if not isinstance(v, ImageURIArray): + raise TypeError("Invalid input array type", type(v)) + return v if PYDANTIC_VERSION < (2, 0): diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 58dd7c83..d6afb2ec 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -18,6 +18,7 @@ import os from datetime import date, datetime from typing import List, Optional, Tuple +import numpy as np import pyarrow as pa import pydantic import pytest @@ -256,18 +257,11 @@ def test_lance_model(): def test_lance_model_with_lance_types(): - img = ( - b"\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00" - b"\\x01\\x00\\x00\\x00\\x01\\x08\\x06\\x00\\x00\\x00\\x1f" - ) png_uris = [ "file://" + os.path.join(os.path.dirname(__file__), "images/1.png"), ] default_image_uris = ImageURIArray.from_uris(png_uris) - encoded_images = pa.array([img], pa.binary()) - default_encoded_images = pa.ExtensionArray.from_storage( - EncodedImageType(), encoded_images - ) + default_encoded_images = default_image_uris.read_uris() class TestModel(LanceModel): encoded_images: EncodedImage() = Field(default=default_encoded_images) @@ -278,6 +272,7 @@ def test_lance_model_with_lance_types(): assert TestModel.field_names() == ["encoded_images", "image_uris"] expected_model = TestModel() + actual_model = TestModel( encoded_images=default_encoded_images, image_uris=default_image_uris )