__get_pydantic_core_schema__

This commit is contained in:
Rok Mihevc
2024-01-06 01:34:47 +01:00
committed by Chang She
parent 2ec0e79303
commit 0b0f4e9d1c
2 changed files with 67 additions and 12 deletions

View File

@@ -36,10 +36,22 @@ import numpy as np
import pyarrow as pa
import pydantic
import semver
from pydantic.fields import FieldInfo
from lance.arrow import (
EncodedImageScalar,
ImageURIScalar,
ImageURIArray,
EncodedImageType,
EncodedImageArray,
)
from .embeddings import EmbeddingFunctionRegistry
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
@@ -225,20 +237,42 @@ def EncodedImage() -> Type[ImageMixin]:
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.str_schema(),
def validate_from_bytes(value: bytes) -> EncodedImageScalar:
return EncodedImageScalar(value)
from_bytes_schema = core_schema.chain_schema(
[
core_schema.bytes_schema(),
core_schema.no_info_plain_validator_function(validate_from_bytes),
]
)
return core_schema.json_or_python_schema(
json_schema=from_bytes_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(EncodedImageArray),
from_bytes_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.bytes_schema())
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
# For pydantic v2
@classmethod
def validate(cls, v):
from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray
if isinstance(v, ImageURIArray):
v = v.read_uris()
if isinstance(v, pa.BinaryArray):
@@ -290,16 +324,34 @@ def ImageURI() -> Type[ImageMixin]:
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.str_schema(),
def validate_from_str(value: str) -> ImageURIScalar:
return ImageURIScalar(value)
from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)
return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(ImageURIArray),
from_str_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
# For pydantic v2
@classmethod
def validate(cls, v):
from lance.arrow import ImageURIArray, ImageURIType

View File

@@ -18,7 +18,7 @@ import os
from datetime import date, datetime
from typing import List, Optional, Tuple
import numpy as np
from pathlib import Path
import pyarrow as pa
import pydantic
import pytest
@@ -260,6 +260,9 @@ def test_lance_model_with_lance_types():
png_uris = [
"file://" + os.path.join(os.path.dirname(__file__), "images/1.png"),
]
if os.name == "nt":
png_uris = [str(Path(x)) for x in png_uris]
default_image_uris = ImageURIArray.from_uris(png_uris)
default_encoded_images = default_image_uris.read_uris()
@@ -279,6 +282,6 @@ def test_lance_model_with_lance_types():
assert expected_model == actual_model
actual_model = TestModel(
encoded_images=default_image_uris, image_uris=default_image_uris
encoded_images=default_encoded_images, image_uris=default_image_uris
)
assert expected_model == actual_model