mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-11 06:12:58 +00:00
changes
This commit is contained in:
@@ -237,9 +237,15 @@ def EncodedImage() -> Type[ImageMixin]:
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
# if not isinstance(v, bytes):
|
||||
# raise TypeError("A bytes is needed")
|
||||
return cls(v)
|
||||
from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray
|
||||
if isinstance(v, ImageURIArray):
|
||||
v = v.read_uris()
|
||||
if isinstance(v, pa.BinaryArray):
|
||||
v = pa.ExtensionArray.from_storage(EncodedImageType(), v)
|
||||
if not isinstance(v, EncodedImageArray):
|
||||
raise TypeError("Invalid input array type", type(v))
|
||||
|
||||
return v
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@@ -297,7 +303,8 @@ def ImageURI() -> Type[ImageMixin]:
|
||||
def validate(cls, v):
|
||||
# if not isinstance(v, str):
|
||||
# raise TypeError("A str is needed")
|
||||
return cls(v)
|
||||
# return cls(v)
|
||||
return v
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
|
||||
@@ -20,24 +20,16 @@ from typing import List, Optional, Tuple
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
LanceModel,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
EncodedImage,
|
||||
ImageURI,
|
||||
)
|
||||
import pytz
|
||||
from pydantic import Field
|
||||
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
|
||||
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
EncodedImage,
|
||||
ImageURI,
|
||||
LanceModel,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
EncodedImage,
|
||||
ImageURI,
|
||||
)
|
||||
|
||||
|
||||
@@ -263,22 +255,24 @@ def test_lance_model():
|
||||
|
||||
|
||||
def test_lance_model_with_lance_types():
|
||||
class TestModel(LanceModel):
|
||||
encoded_images: EncodedImage() = Field(default=[])
|
||||
image_uris: ImageURI() = Field(default=[])
|
||||
# TODO: tensor type?
|
||||
img = (b'\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00'
|
||||
b'\\x01\\x00\\x00\\x00\\x01\\x08\\x06\\x00\\x00\\x00\\x1f')
|
||||
default_image_uris = ImageURIArray.from_uris(["/tmp/bar"])
|
||||
encoded_images = pa.array([img], pa.binary())
|
||||
default_encoded_images = pa.ExtensionArray.from_storage(EncodedImageType(), encoded_images)
|
||||
|
||||
class TestModel(LanceModel):
|
||||
encoded_images: EncodedImage() = Field(default=default_encoded_images)
|
||||
image_uris: ImageURI() = Field(default=default_image_uris)
|
||||
|
||||
# TODO
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["encoded_images", "image_uris"]
|
||||
|
||||
encoded_images = pa.ExtensionArray.from_storage(
|
||||
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
|
||||
)
|
||||
image_uris = pa.ExtensionArray.from_storage(
|
||||
ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string())
|
||||
)
|
||||
expected_model = TestModel()
|
||||
actual_model = TestModel(encoded_images=default_encoded_images, image_uris=default_image_uris)
|
||||
assert expected_model == actual_model
|
||||
|
||||
t = TestModel()
|
||||
# assert t == TestModel(encoded_images=encoded_images, image_uris=image_uris)
|
||||
# TODO: add images to repo and test with real path
|
||||
# actual_model = TestModel(encoded_images=default_image_uris, image_uris=default_image_uris)
|
||||
# assert expected_model == actual_model
|
||||
|
||||
Reference in New Issue
Block a user