From c112dea28b5a5e37475e6ead2499a94cce1d3664 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Wed, 25 Oct 2023 12:31:34 +0200 Subject: [PATCH] work --- python/python/lancedb/pydantic.py | 11 +++++++---- python/python/tests/test_pydantic.py | 24 ++++++++++++++++-------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 58b6294f..e3ddcee2 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -237,8 +237,8 @@ def EncodedImage() -> Type[ImageMixin]: # For pydantic v1 @classmethod def validate(cls, v): - if not isinstance(v, bytes): - raise TypeError("A bytes is needed") + # if not isinstance(v, bytes): + # raise TypeError("A bytes is needed") return cls(v) if PYDANTIC_VERSION < (2, 0): @@ -295,8 +295,8 @@ def ImageURI() -> Type[ImageMixin]: # For pydantic v1 @classmethod def validate(cls, v): - if not isinstance(v, str): - raise TypeError("A str is needed") + # if not isinstance(v, str): + # raise TypeError("A str is needed") return cls(v) if PYDANTIC_VERSION < (2, 0): @@ -353,6 +353,9 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: return pa.struct(fields) elif issubclass(field.annotation, FixedSizeListMixin): return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim()) + elif issubclass(field.annotation, ImageMixin): + return field.annotation.value_arrow_type() + return _py_type_to_arrow_type(field.annotation, field) diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index d8d57e2e..7911776b 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -29,6 +29,7 @@ from lancedb.pydantic import ( ImageURI, ) from pydantic import Field +from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType from lancedb.pydantic import ( PYDANTIC_VERSION, @@ -263,14 +264,21 @@ def test_lance_model(): def test_lance_model_with_lance_types(): class TestModel(LanceModel): - image: EncodedImage() = Field() - uri: ImageURI() = Field() + encoded_images: EncodedImage() = Field(default=[]) + image_uris: ImageURI() = Field(default=[]) # TODO: tensor type? # TODO - # schema = pydantic_to_schema(TestModel) - # assert schema == TestModel.to_arrow_schema() - # assert TestModel.field_names() == ["image", "uri"] - # - # t = TestModel() - # assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3], image=EncodedImageArray(), uri="https://lancedb.dev") + schema = pydantic_to_schema(TestModel) + assert schema == TestModel.to_arrow_schema() + assert TestModel.field_names() == ["encoded_images", "image_uris"] + + encoded_images = pa.ExtensionArray.from_storage( + EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary()) + ) + image_uris = pa.ExtensionArray.from_storage( + ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string()) + ) + + t = TestModel() + # assert t == TestModel(encoded_images=encoded_images, image_uris=image_uris)