mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-13 15:22:57 +00:00
work
This commit is contained in:
@@ -237,8 +237,8 @@ def EncodedImage() -> Type[ImageMixin]:
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, bytes):
|
||||
raise TypeError("A bytes is needed")
|
||||
# if not isinstance(v, bytes):
|
||||
# raise TypeError("A bytes is needed")
|
||||
return cls(v)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
@@ -295,8 +295,8 @@ def ImageURI() -> Type[ImageMixin]:
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, str):
|
||||
raise TypeError("A str is needed")
|
||||
# if not isinstance(v, str):
|
||||
# raise TypeError("A str is needed")
|
||||
return cls(v)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
@@ -353,6 +353,9 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
return pa.struct(fields)
|
||||
elif issubclass(field.annotation, FixedSizeListMixin):
|
||||
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
||||
elif issubclass(field.annotation, ImageMixin):
|
||||
return field.annotation.value_arrow_type()
|
||||
|
||||
return _py_type_to_arrow_type(field.annotation, field)
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from lancedb.pydantic import (
|
||||
ImageURI,
|
||||
)
|
||||
from pydantic import Field
|
||||
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
|
||||
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
@@ -263,14 +264,21 @@ def test_lance_model():
|
||||
|
||||
def test_lance_model_with_lance_types():
|
||||
class TestModel(LanceModel):
|
||||
image: EncodedImage() = Field()
|
||||
uri: ImageURI() = Field()
|
||||
encoded_images: EncodedImage() = Field(default=[])
|
||||
image_uris: ImageURI() = Field(default=[])
|
||||
# TODO: tensor type?
|
||||
|
||||
# TODO
|
||||
# schema = pydantic_to_schema(TestModel)
|
||||
# assert schema == TestModel.to_arrow_schema()
|
||||
# assert TestModel.field_names() == ["image", "uri"]
|
||||
#
|
||||
# t = TestModel()
|
||||
# assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3], image=EncodedImageArray(), uri="https://lancedb.dev")
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["encoded_images", "image_uris"]
|
||||
|
||||
encoded_images = pa.ExtensionArray.from_storage(
|
||||
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
|
||||
)
|
||||
image_uris = pa.ExtensionArray.from_storage(
|
||||
ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string())
|
||||
)
|
||||
|
||||
t = TestModel()
|
||||
# assert t == TestModel(encoded_images=encoded_images, image_uris=image_uris)
|
||||
|
||||
Reference in New Issue
Block a user