mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 22:02:58 +00:00
__get_pydantic_core_schema__
This commit is contained in:
@@ -36,10 +36,22 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
from pydantic.fields import FieldInfo
|
||||
from lance.arrow import (
|
||||
EncodedImageScalar,
|
||||
ImageURIScalar,
|
||||
ImageURIArray,
|
||||
EncodedImageType,
|
||||
EncodedImageArray,
|
||||
)
|
||||
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import GetJsonSchemaHandler
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
except ImportError:
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
raise
|
||||
@@ -225,20 +237,42 @@ def EncodedImage() -> Type[ImageMixin]:
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.str_schema(),
|
||||
def validate_from_bytes(value: bytes) -> EncodedImageScalar:
|
||||
return EncodedImageScalar(value)
|
||||
|
||||
from_bytes_schema = core_schema.chain_schema(
|
||||
[
|
||||
core_schema.bytes_schema(),
|
||||
core_schema.no_info_plain_validator_function(validate_from_bytes),
|
||||
]
|
||||
)
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=from_bytes_schema,
|
||||
python_schema=core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(EncodedImageArray),
|
||||
from_bytes_schema,
|
||||
]
|
||||
),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda instance: instance.values
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
return handler(core_schema.bytes_schema())
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
# For pydantic v2
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray
|
||||
|
||||
if isinstance(v, ImageURIArray):
|
||||
v = v.read_uris()
|
||||
if isinstance(v, pa.BinaryArray):
|
||||
@@ -290,16 +324,34 @@ def ImageURI() -> Type[ImageMixin]:
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.str_schema(),
|
||||
def validate_from_str(value: str) -> ImageURIScalar:
|
||||
return ImageURIScalar(value)
|
||||
|
||||
from_str_schema = core_schema.chain_schema(
|
||||
[
|
||||
core_schema.str_schema(),
|
||||
core_schema.no_info_plain_validator_function(validate_from_str),
|
||||
]
|
||||
)
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=from_str_schema,
|
||||
python_schema=core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(ImageURIArray),
|
||||
from_str_schema,
|
||||
]
|
||||
),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda instance: instance.values
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
# For pydantic v2
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
from lance.arrow import ImageURIArray, ImageURIType
|
||||
|
||||
@@ -18,7 +18,7 @@ import os
|
||||
from datetime import date, datetime
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
@@ -260,6 +260,9 @@ def test_lance_model_with_lance_types():
|
||||
png_uris = [
|
||||
"file://" + os.path.join(os.path.dirname(__file__), "images/1.png"),
|
||||
]
|
||||
if os.name == "nt":
|
||||
png_uris = [str(Path(x)) for x in png_uris]
|
||||
|
||||
default_image_uris = ImageURIArray.from_uris(png_uris)
|
||||
default_encoded_images = default_image_uris.read_uris()
|
||||
|
||||
@@ -279,6 +282,6 @@ def test_lance_model_with_lance_types():
|
||||
assert expected_model == actual_model
|
||||
|
||||
actual_model = TestModel(
|
||||
encoded_images=default_image_uris, image_uris=default_image_uris
|
||||
encoded_images=default_encoded_images, image_uris=default_image_uris
|
||||
)
|
||||
assert expected_model == actual_model
|
||||
|
||||
Reference in New Issue
Block a user