mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-12 06:42:56 +00:00
Minor change
This commit is contained in:
@@ -227,7 +227,7 @@ def EncodedImage() -> Type[ImageMixin]:
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.binary_schema(),
|
||||
core_schema.str_schema(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -292,7 +292,7 @@ def ImageURI() -> Type[ImageMixin]:
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.string_schema(),
|
||||
core_schema.str_schema(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -302,9 +302,13 @@ def ImageURI() -> Type[ImageMixin]:
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
# if not isinstance(v, str):
|
||||
# raise TypeError("A str is needed")
|
||||
# return cls(v)
|
||||
from lance.arrow import ImageURIArray, ImageURIType
|
||||
|
||||
if isinstance(v, (str, pa.StringArray)):
|
||||
v = pa.ExtensionArray.from_storage(ImageURIType(), v)
|
||||
if not isinstance(v, ImageURIArray):
|
||||
raise TypeError("Invalid input array type", type(v))
|
||||
|
||||
return v
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@@ -18,6 +18,7 @@ import os
|
||||
from datetime import date, datetime
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
@@ -256,18 +257,11 @@ def test_lance_model():
|
||||
|
||||
|
||||
def test_lance_model_with_lance_types():
|
||||
img = (
|
||||
b"\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00"
|
||||
b"\\x01\\x00\\x00\\x00\\x01\\x08\\x06\\x00\\x00\\x00\\x1f"
|
||||
)
|
||||
png_uris = [
|
||||
"file://" + os.path.join(os.path.dirname(__file__), "images/1.png"),
|
||||
]
|
||||
default_image_uris = ImageURIArray.from_uris(png_uris)
|
||||
encoded_images = pa.array([img], pa.binary())
|
||||
default_encoded_images = pa.ExtensionArray.from_storage(
|
||||
EncodedImageType(), encoded_images
|
||||
)
|
||||
default_encoded_images = default_image_uris.read_uris()
|
||||
|
||||
class TestModel(LanceModel):
|
||||
encoded_images: EncodedImage() = Field(default=default_encoded_images)
|
||||
@@ -278,6 +272,7 @@ def test_lance_model_with_lance_types():
|
||||
assert TestModel.field_names() == ["encoded_images", "image_uris"]
|
||||
|
||||
expected_model = TestModel()
|
||||
|
||||
actual_model = TestModel(
|
||||
encoded_images=default_encoded_images, image_uris=default_image_uris
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user