From 0b0f4e9d1c81d3b24c8cd700252349d689ba296b Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Sat, 6 Jan 2024 01:34:47 +0100 Subject: [PATCH] __get_pydantic_core_schema__ --- python/python/lancedb/pydantic.py | 72 ++++++++++++++++++++++++---- python/python/tests/test_pydantic.py | 7 ++- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 355cfcb1..54cd443e 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -36,10 +36,22 @@ import numpy as np import pyarrow as pa import pydantic import semver +from pydantic.fields import FieldInfo +from lance.arrow import ( + EncodedImageScalar, + ImageURIScalar, + ImageURIArray, + EncodedImageType, + EncodedImageArray, +) + +from .embeddings import EmbeddingFunctionRegistry PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) try: from pydantic_core import CoreSchema, core_schema + from pydantic import GetJsonSchemaHandler + from pydantic.json_schema import JsonSchemaValue except ImportError: if PYDANTIC_VERSION >= (2,): raise @@ -225,20 +237,42 @@ def EncodedImage() -> Type[ImageMixin]: def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler ) -> CoreSchema: - return core_schema.no_info_after_validator_function( - cls, - core_schema.str_schema(), + def validate_from_bytes(value: bytes) -> EncodedImageScalar: + return EncodedImageScalar(value) + + from_bytes_schema = core_schema.chain_schema( + [ + core_schema.bytes_schema(), + core_schema.no_info_plain_validator_function(validate_from_bytes), + ] ) + return core_schema.json_or_python_schema( + json_schema=from_bytes_schema, + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(EncodedImageArray), + from_bytes_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.values + ), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.bytes_schema()) + @classmethod def __get_validators__(cls) -> Generator[Callable, None, None]: yield cls.validate - # For pydantic v1 + # For pydantic v2 @classmethod def validate(cls, v): - from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray - if isinstance(v, ImageURIArray): v = v.read_uris() if isinstance(v, pa.BinaryArray): @@ -290,16 +324,34 @@ def ImageURI() -> Type[ImageMixin]: def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler ) -> CoreSchema: - return core_schema.no_info_after_validator_function( - cls, - core_schema.str_schema(), + def validate_from_str(value: str) -> ImageURIScalar: + return ImageURIScalar(value) + + from_str_schema = core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(validate_from_str), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_str_schema, + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(ImageURIArray), + from_str_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.values + ), ) @classmethod def __get_validators__(cls) -> Generator[Callable, None, None]: yield cls.validate - # For pydantic v1 + # For pydantic v2 @classmethod def validate(cls, v): from lance.arrow import ImageURIArray, ImageURIType diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index d6afb2ec..3cbc7dcc 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -18,7 +18,7 @@ import os from datetime import date, datetime from typing import List, Optional, Tuple -import numpy as np +from pathlib import Path import pyarrow as pa import pydantic import pytest @@ -260,6 +260,9 @@ def test_lance_model_with_lance_types(): png_uris = [ "file://" + os.path.join(os.path.dirname(__file__), "images/1.png"), ] + if os.name == "nt": + png_uris = [str(Path(x)) for x in png_uris] + default_image_uris = ImageURIArray.from_uris(png_uris) default_encoded_images = default_image_uris.read_uris() @@ -279,6 +282,6 @@ def test_lance_model_with_lance_types(): assert expected_model == actual_model actual_model = TestModel( - encoded_images=default_image_uris, image_uris=default_image_uris + encoded_images=default_encoded_images, image_uris=default_image_uris ) assert expected_model == actual_model