This commit is contained in:
Rok Mihevc
2024-01-05 18:46:24 +01:00
committed by Chang She
parent c112dea28b
commit 67b38d6115
2 changed files with 29 additions and 28 deletions

View File

@@ -237,9 +237,15 @@ def EncodedImage() -> Type[ImageMixin]:
# For pydantic v1
@classmethod
def validate(cls, v):
# if not isinstance(v, bytes):
# raise TypeError("A bytes is needed")
return cls(v)
from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray
if isinstance(v, ImageURIArray):
v = v.read_uris()
if isinstance(v, pa.BinaryArray):
v = pa.ExtensionArray.from_storage(EncodedImageType(), v)
if not isinstance(v, EncodedImageArray):
raise TypeError("Invalid input array type", type(v))
return v
if PYDANTIC_VERSION < (2, 0):
@@ -297,7 +303,8 @@ def ImageURI() -> Type[ImageMixin]:
def validate(cls, v):
# if not isinstance(v, str):
# raise TypeError("A str is needed")
return cls(v)
# return cls(v)
return v
if PYDANTIC_VERSION < (2, 0):

View File

@@ -20,24 +20,16 @@ from typing import List, Optional, Tuple
import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import (
PYDANTIC_VERSION,
LanceModel,
Vector,
pydantic_to_schema,
EncodedImage,
ImageURI,
)
import pytz
from pydantic import Field
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
from lancedb.pydantic import (
PYDANTIC_VERSION,
EncodedImage,
ImageURI,
LanceModel,
Vector,
pydantic_to_schema,
EncodedImage,
ImageURI,
)
@@ -263,22 +255,24 @@ def test_lance_model():
def test_lance_model_with_lance_types():
class TestModel(LanceModel):
encoded_images: EncodedImage() = Field(default=[])
image_uris: ImageURI() = Field(default=[])
# TODO: tensor type?
img = (b'\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00'
b'\\x01\\x00\\x00\\x00\\x01\\x08\\x06\\x00\\x00\\x00\\x1f')
default_image_uris = ImageURIArray.from_uris(["/tmp/bar"])
encoded_images = pa.array([img], pa.binary())
default_encoded_images = pa.ExtensionArray.from_storage(EncodedImageType(), encoded_images)
class TestModel(LanceModel):
encoded_images: EncodedImage() = Field(default=default_encoded_images)
image_uris: ImageURI() = Field(default=default_image_uris)
# TODO
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["encoded_images", "image_uris"]
encoded_images = pa.ExtensionArray.from_storage(
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
)
image_uris = pa.ExtensionArray.from_storage(
ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string())
)
expected_model = TestModel()
actual_model = TestModel(encoded_images=default_encoded_images, image_uris=default_image_uris)
assert expected_model == actual_model
t = TestModel()
# assert t == TestModel(encoded_images=encoded_images, image_uris=image_uris)
# TODO: add images to repo and test with real path
# actual_model = TestModel(encoded_images=default_image_uris, image_uris=default_image_uris)
# assert expected_model == actual_model