This commit is contained in:
Chang She
2024-03-03 12:27:18 -08:00
parent 63399dc0ee
commit e68fbf65cc
2 changed files with 109 additions and 118 deletions

View File

@@ -45,8 +45,6 @@ from lance.arrow import (
EncodedImageArray,
)
from .embeddings import EmbeddingFunctionRegistry
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
@@ -205,8 +203,8 @@ class ImageMixin(ABC):
raise NotImplementedError
def EncodedImage() -> Type[ImageMixin]:
"""Pydantic EncodedImage Type.
class Image(bytes, ImageMixin):
"""Pydantic type for inlined images.
!!! warning
Experimental feature.
@@ -215,84 +213,81 @@ def EncodedImage() -> Type[ImageMixin]:
--------
>>> import pydantic
>>> from lancedb.pydantic import EncodedImage
>>> from lancedb.pydantic import Image
...
>>> class MyModel(pydantic.BaseModel):
... image: EncodedImage()
... image: Image
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("image", pa.binary(), False)
... ])
"""
class EncodedImage(bytes, ImageMixin):
def __repr__(self):
return "EncodedImage()"
def __repr__(self):
return "Image()"
@staticmethod
def value_arrow_type() -> pa.DataType:
return pa.binary()
@staticmethod
def value_arrow_type() -> pa.DataType:
return pa.binary()
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
def validate_from_bytes(value: bytes) -> EncodedImageScalar:
return EncodedImageScalar(value)
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
def validate_from_bytes(value: bytes) -> EncodedImageScalar:
return EncodedImageScalar(value)
from_bytes_schema = core_schema.chain_schema(
from_bytes_schema = core_schema.chain_schema(
[
core_schema.bytes_schema(),
core_schema.no_info_plain_validator_function(validate_from_bytes),
]
)
return core_schema.json_or_python_schema(
json_schema=from_bytes_schema,
python_schema=core_schema.union_schema(
[
core_schema.bytes_schema(),
core_schema.no_info_plain_validator_function(validate_from_bytes),
core_schema.is_instance_schema(EncodedImageArray),
from_bytes_schema,
]
)
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
return core_schema.json_or_python_schema(
json_schema=from_bytes_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(EncodedImageArray),
from_bytes_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.bytes_schema())
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# Reached through `__get_validators__` above, which is the pydantic v1 hook (not v2)
@classmethod
def validate(cls, v):
if isinstance(v, ImageURIArray):
v = v.read_uris()
if isinstance(v, pa.BinaryArray):
v = pa.ExtensionArray.from_storage(EncodedImageType(), v)
if not isinstance(v, EncodedImageArray):
raise TypeError("Invalid input array type", type(v))
return v
if PYDANTIC_VERSION < (2, 0):
@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.bytes_schema())
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# Reached through `__get_validators__` above, which is the pydantic v1 hook (not v2)
@classmethod
def validate(cls, v):
if isinstance(v, ImageURIArray):
v = v.read_uris()
if isinstance(v, pa.BinaryArray):
v = pa.ExtensionArray.from_storage(EncodedImageType(), v)
if not isinstance(v, EncodedImageArray):
raise TypeError("Invalid input array type", type(v))
return v
if PYDANTIC_VERSION < (2, 0):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "binary"
return EncodedImage
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "binary"
def ImageURI() -> Type[ImageMixin]:
class ImageURI(str, ImageMixin):
"""Pydantic ImageUri Type.
!!! warning
@@ -310,67 +305,63 @@ def ImageURI() -> Type[ImageMixin]:
>>> assert schema == pa.schema([
... pa.field("url", pa.utf8(), False),
... ])
"""
"""
def __repr__(self):
return "ImageURI()"
class ImageURI(str, ImageMixin):
def __repr__(self):
return "ImageURI()"
@staticmethod
def value_arrow_type() -> pa.DataType:
return pa.string()
@staticmethod
def value_arrow_type() -> pa.DataType:
return pa.string()
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
def validate_from_str(value: str) -> ImageURIScalar:
return ImageURIScalar(value)
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
def validate_from_str(value: str) -> ImageURIScalar:
return ImageURIScalar(value)
from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)
from_str_schema = core_schema.chain_schema(
return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
core_schema.is_instance_schema(ImageURIArray),
from_str_schema,
]
)
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(ImageURIArray),
from_str_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# Reached through `__get_validators__` above, which is the pydantic v1 hook (not v2)
@classmethod
def validate(cls, v):
from lance.arrow import ImageURIArray, ImageURIType
if isinstance(v, (str, pa.StringArray)):
v = pa.ExtensionArray.from_storage(ImageURIType(), v)
if not isinstance(v, ImageURIArray):
raise TypeError("Invalid input array type", type(v))
return v
if PYDANTIC_VERSION < (2, 0):
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# Reached through `__get_validators__` above, which is the pydantic v1 hook (not v2)
@classmethod
def validate(cls, v):
from lance.arrow import ImageURIArray, ImageURIType
if isinstance(v, (str, pa.StringArray)):
v = pa.ExtensionArray.from_storage(ImageURIType(), v)
if not isinstance(v, ImageURIArray):
raise TypeError("Invalid input array type", type(v))
return v
if PYDANTIC_VERSION < (2, 0):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "string"
return ImageURI
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "string"
if PYDANTIC_VERSION.major < 2:

View File

@@ -28,7 +28,7 @@ from pydantic import Field
from lancedb.pydantic import (
PYDANTIC_VERSION,
EncodedImage,
Image,
ImageURI,
LanceModel,
Vector,
@@ -268,7 +268,7 @@ def test_lance_model_with_lance_types():
default_encoded_images = default_image_uris.read_uris()
class TestModel(LanceModel):
encoded_images: EncodedImage() = Field(default=default_encoded_images)
encoded_images: Image() = Field(default=default_encoded_images)
image_uris: ImageURI() = Field(default=default_image_uris)
schema = pydantic_to_schema(TestModel)