From 67b38d611502c838cdc920223e72bf7bf6a14e89 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Fri, 5 Jan 2024 18:46:24 +0100 Subject: [PATCH] changes --- python/python/lancedb/pydantic.py | 15 +++++++--- python/python/tests/test_pydantic.py | 42 ++++++++++++---------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index e3ddcee2..b33f8270 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -237,9 +237,15 @@ def EncodedImage() -> Type[ImageMixin]: # For pydantic v1 @classmethod def validate(cls, v): - # if not isinstance(v, bytes): - # raise TypeError("A bytes is needed") - return cls(v) + from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray + if isinstance(v, ImageURIArray): + v = v.read_uris() + if isinstance(v, pa.BinaryArray): + v = pa.ExtensionArray.from_storage(EncodedImageType(), v) + if not isinstance(v, EncodedImageArray): + raise TypeError("Invalid input array type", type(v)) + + return v if PYDANTIC_VERSION < (2, 0): @@ -297,7 +303,8 @@ def ImageURI() -> Type[ImageMixin]: def validate(cls, v): # if not isinstance(v, str): # raise TypeError("A str is needed") - return cls(v) + # return cls(v) + return v if PYDANTIC_VERSION < (2, 0): diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 7911776b..204f91a4 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -20,24 +20,16 @@ from typing import List, Optional, Tuple import pyarrow as pa import pydantic import pytest -from lancedb.pydantic import ( - PYDANTIC_VERSION, - LanceModel, - Vector, - pydantic_to_schema, - EncodedImage, - ImageURI, -) +import pytz from pydantic import Field -from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType from lancedb.pydantic import ( PYDANTIC_VERSION, + EncodedImage, + ImageURI, LanceModel, Vector, pydantic_to_schema, - EncodedImage, - ImageURI, ) @@ -263,22 +255,24 @@ def test_lance_model(): def test_lance_model_with_lance_types(): - class TestModel(LanceModel): - encoded_images: EncodedImage() = Field(default=[]) - image_uris: ImageURI() = Field(default=[]) - # TODO: tensor type? + img = (b'\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00' + b'\\x01\\x00\\x00\\x00\\x01\\x08\\x06\\x00\\x00\\x00\\x1f') + default_image_uris = ImageURIArray.from_uris(["/tmp/bar"]) + encoded_images = pa.array([img], pa.binary()) + default_encoded_images = pa.ExtensionArray.from_storage(EncodedImageType(), encoded_images) + + class TestModel(LanceModel): + encoded_images: EncodedImage() = Field(default=default_encoded_images) + image_uris: ImageURI() = Field(default=default_image_uris) - # TODO schema = pydantic_to_schema(TestModel) assert schema == TestModel.to_arrow_schema() assert TestModel.field_names() == ["encoded_images", "image_uris"] - encoded_images = pa.ExtensionArray.from_storage( - EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary()) - ) - image_uris = pa.ExtensionArray.from_storage( - ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string()) - ) + expected_model = TestModel() + actual_model = TestModel(encoded_images=default_encoded_images, image_uris=default_image_uris) + assert expected_model == actual_model - t = TestModel() - # assert t == TestModel(encoded_images=encoded_images, image_uris=image_uris) + # TODO: add images to repo and test with real path + # actual_model = TestModel(encoded_images=default_image_uris, image_uris=default_image_uris) + # assert expected_model == actual_model