diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 54cd443e..c2db4481 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -45,8 +45,6 @@ from lance.arrow import ( EncodedImageArray, ) -from .embeddings import EmbeddingFunctionRegistry - PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) try: from pydantic_core import CoreSchema, core_schema @@ -205,8 +203,8 @@ class ImageMixin(ABC): raise NotImplementedError -def EncodedImage() -> Type[ImageMixin]: - """Pydantic EncodedImage Type. +class Image(bytes, ImageMixin): + """Pydantic type for inlined images. !!! warning Experimental feature. @@ -215,84 +213,81 @@ def EncodedImage() -> Type[ImageMixin]: -------- >>> import pydantic - >>> from lancedb.pydantic import EncodedImage + >>> from lancedb.pydantic import Image ... >>> class MyModel(pydantic.BaseModel): - ... image: EncodedImage() + ... image: Image >>> schema = pydantic_to_schema(MyModel) >>> assert schema == pa.schema([ ... pa.field("image", pa.binary(), False) ... ]) """ - class EncodedImage(bytes, ImageMixin): - def __repr__(self): - return "EncodedImage()" + def __repr__(self): + return "Image()" - @staticmethod - def value_arrow_type() -> pa.DataType: - return pa.binary() + @staticmethod + def value_arrow_type() -> pa.DataType: + return pa.binary() - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler - ) -> CoreSchema: - def validate_from_bytes(value: bytes) -> EncodedImageScalar: - return EncodedImageScalar(value) + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler + ) -> CoreSchema: + def validate_from_bytes(value: bytes) -> EncodedImageScalar: + return EncodedImageScalar(value) - from_bytes_schema = core_schema.chain_schema( + from_bytes_schema = core_schema.chain_schema( + [ + core_schema.bytes_schema(), + core_schema.no_info_plain_validator_function(validate_from_bytes), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_bytes_schema, + python_schema=core_schema.union_schema( [ - core_schema.bytes_schema(), - core_schema.no_info_plain_validator_function(validate_from_bytes), + core_schema.is_instance_schema(EncodedImageArray), + from_bytes_schema, ] - ) + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.values + ), + ) - return core_schema.json_or_python_schema( - json_schema=from_bytes_schema, - python_schema=core_schema.union_schema( - [ - core_schema.is_instance_schema(EncodedImageArray), - from_bytes_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: instance.values - ), - ) + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.bytes_schema()) + + @classmethod + def __get_validators__(cls) -> Generator[Callable, None, None]: + yield cls.validate + + # For pydantic v2 + @classmethod + def validate(cls, v): + if isinstance(v, ImageURIArray): + v = v.read_uris() + if isinstance(v, pa.BinaryArray): + v = pa.ExtensionArray.from_storage(EncodedImageType(), v) + if not isinstance(v, EncodedImageArray): + raise TypeError("Invalid input array type", type(v)) + + return v + + if PYDANTIC_VERSION < (2, 0): @classmethod - def __get_pydantic_json_schema__( - cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - return handler(core_schema.bytes_schema()) - - @classmethod - def __get_validators__(cls) -> Generator[Callable, None, None]: - yield cls.validate - - # For pydantic v2 - @classmethod - def validate(cls, v): - if isinstance(v, ImageURIArray): - v = v.read_uris() - if isinstance(v, pa.BinaryArray): - v = pa.ExtensionArray.from_storage(EncodedImageType(), v) - if not isinstance(v, EncodedImageArray): - raise TypeError("Invalid input array type", type(v)) - - return v - - if PYDANTIC_VERSION < (2, 0): - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]): - field_schema["type"] = "string" - field_schema["format"] = "binary" - - return EncodedImage + def __modify_schema__(cls, field_schema: Dict[str, Any]): + field_schema["type"] = "string" + field_schema["format"] = "binary" -def ImageURI() -> Type[ImageMixin]: +class ImageURI(str, ImageMixin): """Pydantic ImageUri Type. !!! warning @@ -310,67 +305,63 @@ def ImageURI() -> Type[ImageMixin]: >>> assert schema == pa.schema([ ... pa.field("url", pa.utf8(), False), ... ]) - """ + """ + def __repr__(self): + return "ImageURI()" - class ImageURI(str, ImageMixin): - def __repr__(self): - return "ImageURI()" + @staticmethod + def value_arrow_type() -> pa.DataType: + return pa.string() - @staticmethod - def value_arrow_type() -> pa.DataType: - return pa.string() + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler + ) -> CoreSchema: + def validate_from_str(value: str) -> ImageURIScalar: + return ImageURIScalar(value) - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler - ) -> CoreSchema: - def validate_from_str(value: str) -> ImageURIScalar: - return ImageURIScalar(value) + from_str_schema = core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(validate_from_str), + ] + ) - from_str_schema = core_schema.chain_schema( + return core_schema.json_or_python_schema( + json_schema=from_str_schema, + python_schema=core_schema.union_schema( [ - core_schema.str_schema(), - core_schema.no_info_plain_validator_function(validate_from_str), + core_schema.is_instance_schema(ImageURIArray), + from_str_schema, ] - ) + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.values + ), + ) - return core_schema.json_or_python_schema( - json_schema=from_str_schema, - python_schema=core_schema.union_schema( - [ - core_schema.is_instance_schema(ImageURIArray), - from_str_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: instance.values - ), - ) + @classmethod + def __get_validators__(cls) -> Generator[Callable, None, None]: + yield cls.validate + + # For pydantic v2 + @classmethod + def validate(cls, v): + from lance.arrow import ImageURIArray, ImageURIType + + if isinstance(v, (str, pa.StringArray)): + v = pa.ExtensionArray.from_storage(ImageURIType(), v) + if not isinstance(v, ImageURIArray): + raise TypeError("Invalid input array type", type(v)) + + return v + + if PYDANTIC_VERSION < (2, 0): @classmethod - def __get_validators__(cls) -> Generator[Callable, None, None]: - yield cls.validate - - # For pydantic v2 - @classmethod - def validate(cls, v): - from lance.arrow import ImageURIArray, ImageURIType - - if isinstance(v, (str, pa.StringArray)): - v = pa.ExtensionArray.from_storage(ImageURIType(), v) - if not isinstance(v, ImageURIArray): - raise TypeError("Invalid input array type", type(v)) - - return v - - if PYDANTIC_VERSION < (2, 0): - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]): - field_schema["type"] = "string" - field_schema["format"] = "string" - - return ImageURI + def __modify_schema__(cls, field_schema: Dict[str, Any]): + field_schema["type"] = "string" + field_schema["format"] = "string" if PYDANTIC_VERSION.major < 2: diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index a23bc166..16e6a678 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -28,7 +28,7 @@ from pydantic import Field from lancedb.pydantic import ( PYDANTIC_VERSION, - EncodedImage, + Image, ImageURI, LanceModel, Vector, @@ -268,7 +268,7 @@ def test_lance_model_with_lance_types(): default_encoded_images = default_image_uris.read_uris() class TestModel(LanceModel): - encoded_images: EncodedImage() = Field(default=default_encoded_images) + encoded_images: Image() = Field(default=default_encoded_images) image_uris: ImageURI() = Field(default=default_image_uris) schema = pydantic_to_schema(TestModel)