This commit is contained in:
Rok Mihevc
2023-10-25 12:31:34 +02:00
committed by Chang She
parent d662b9744e
commit c112dea28b
2 changed files with 23 additions and 12 deletions

View File

@@ -237,8 +237,8 @@ def EncodedImage() -> Type[ImageMixin]:
# For pydantic v1
@classmethod
def validate(cls, v):
if not isinstance(v, bytes):
raise TypeError("A bytes is needed")
# if not isinstance(v, bytes):
# raise TypeError("A bytes is needed")
return cls(v)
if PYDANTIC_VERSION < (2, 0):
@@ -295,8 +295,8 @@ def ImageURI() -> Type[ImageMixin]:
# For pydantic v1
@classmethod
def validate(cls, v):
if not isinstance(v, str):
raise TypeError("A str is needed")
# if not isinstance(v, str):
# raise TypeError("A str is needed")
return cls(v)
if PYDANTIC_VERSION < (2, 0):
@@ -353,6 +353,9 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
elif issubclass(field.annotation, ImageMixin):
return field.annotation.value_arrow_type()
return _py_type_to_arrow_type(field.annotation, field)

View File

@@ -29,6 +29,7 @@ from lancedb.pydantic import (
ImageURI,
)
from pydantic import Field
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
from lancedb.pydantic import (
PYDANTIC_VERSION,
@@ -263,14 +264,21 @@ def test_lance_model():
def test_lance_model_with_lance_types():
class TestModel(LanceModel):
image: EncodedImage() = Field()
uri: ImageURI() = Field()
encoded_images: EncodedImage() = Field(default=[])
image_uris: ImageURI() = Field(default=[])
# TODO: tensor type?
# TODO
# schema = pydantic_to_schema(TestModel)
# assert schema == TestModel.to_arrow_schema()
# assert TestModel.field_names() == ["image", "uri"]
#
# t = TestModel()
# assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3], image=EncodedImageArray(), uri="https://lancedb.dev")
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["encoded_images", "image_uris"]
encoded_images = pa.ExtensionArray.from_storage(
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
)
image_uris = pa.ExtensionArray.from_storage(
ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string())
)
t = TestModel()
# assert t == TestModel(encoded_images=encoded_images, image_uris=image_uris)