just keep EncodedImage for now

This commit is contained in:
Chang She
2024-03-03 21:36:46 -08:00
parent e68fbf65cc
commit 408988abce
5 changed files with 117 additions and 184 deletions

View File

@@ -57,6 +57,7 @@ tests = [
"duckdb",
"pytz",
"polars>=0.19",
"pillow"
]
dev = ["ruff", "pre-commit"]
docs = [

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import inspect
import io
import sys
import types
from abc import ABC, abstractmethod
@@ -38,11 +39,12 @@ import pydantic
import semver
from pydantic.fields import FieldInfo
from lance.arrow import (
EncodedImageScalar,
ImageURIScalar,
ImageURIArray,
EncodedImageType,
EncodedImageArray,
EncodedImageScalar,
EncodedImageType,
ImageURIArray,
ImageURIScalar,
ImageURIType,
)
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
@@ -203,165 +205,87 @@ class ImageMixin(ABC):
raise NotImplementedError
class Image(bytes, ImageMixin):
"""Pydantic type for inlined images.
def EncodedImage():
import PIL.Image
!!! warning
Experimental feature.
class EncodedImage(bytes, ImageMixin):
"""Pydantic type for inlined images.
Examples
--------
!!! warning
Experimental feature.
>>> import pydantic
>>> from lancedb.pydantic import Image
...
>>> class MyModel(pydantic.BaseModel):
... image: Image
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("image", pa.binary(), False)
... ])
"""
Examples
--------
def __repr__(self):
return "Image()"
>>> import pydantic
>>> from lancedb.pydantic import EncodedImage
...
>>> class MyModel(pydantic.BaseModel):
... image: EncodedImage()
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("image", pa.binary(), False)
... ])
"""
@staticmethod
def value_arrow_type() -> pa.DataType:
return pa.binary()
def __repr__(self):
return "EncodedImage()"
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
def validate_from_bytes(value: bytes) -> EncodedImageScalar:
return EncodedImageScalar(value)
from_bytes_schema = core_schema.chain_schema(
[
core_schema.bytes_schema(),
core_schema.no_info_plain_validator_function(validate_from_bytes),
]
)
return core_schema.json_or_python_schema(
json_schema=from_bytes_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(EncodedImageArray),
from_bytes_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.bytes_schema())
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v2
@classmethod
def validate(cls, v):
if isinstance(v, ImageURIArray):
v = v.read_uris()
if isinstance(v, pa.BinaryArray):
v = pa.ExtensionArray.from_storage(EncodedImageType(), v)
if not isinstance(v, EncodedImageArray):
raise TypeError("Invalid input array type", type(v))
return v
if PYDANTIC_VERSION < (2, 0):
@staticmethod
def value_arrow_type() -> pa.DataType:
return EncodedImageType()
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "binary"
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
from_bytes_schema = core_schema.bytes_schema()
class ImageURI(str, ImageMixin):
"""Pydantic ImageUri Type.
!!! warning
Experimental feature.
Examples
--------
>>> import pydantic
>>> from lancedb.pydantic import ImageURI
...
>>> class MyModel(pydantic.BaseModel):
... url: ImageURI()
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("url", pa.utf8(), False),
... ])
"""
def __repr__(self):
return "ImageURI()"
@staticmethod
def value_arrow_type() -> pa.DataType:
return pa.string()
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
def validate_from_str(value: str) -> ImageURIScalar:
return ImageURIScalar(value)
from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)
return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(ImageURIArray),
from_str_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v2
@classmethod
def validate(cls, v):
from lance.arrow import ImageURIArray, ImageURIType
if isinstance(v, (str, pa.StringArray)):
v = pa.ExtensionArray.from_storage(ImageURIType(), v)
if not isinstance(v, ImageURIArray):
raise TypeError("Invalid input array type", type(v))
return v
if PYDANTIC_VERSION < (2, 0):
return core_schema.json_or_python_schema(
json_schema=from_bytes_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(PIL.Image.Image),
from_bytes_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: cls.validate(instance)
),
)
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "string"
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.bytes_schema())
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v2
@classmethod
def validate(cls, v):
    """Coerce ``v`` into the encoded bytes of an image.

    Parameters
    ----------
    v : bytes or PIL.Image.Image
        Already-encoded image bytes are returned unchanged. A PIL image is
        encoded into an in-memory buffer and the resulting bytes returned.

    Returns
    -------
    bytes
        The encoded image.

    Raises
    ------
    TypeError
        If ``v`` is neither ``bytes`` nor a ``PIL.Image.Image``.
    """
    if isinstance(v, bytes):
        return v
    if isinstance(v, PIL.Image.Image):
        # ``v.format`` is None for images that were not read from a file
        # (e.g. ``PIL.Image.new`` or the result of ``resize``), and ``save``
        # cannot infer a format from a BytesIO target, so fall back to PNG
        # instead of letting ``save`` raise.
        with io.BytesIO() as output:
            v.save(output, format=v.format or "PNG")
            return output.getvalue()
    raise TypeError(
        "EncodedImage can take bytes or PIL.Image.Image "
        f"as input but got {type(v)}"
    )
if PYDANTIC_VERSION < (2, 0):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
field_schema["type"] = "string"
field_schema["format"] = "binary"
return EncodedImage
if PYDANTIC_VERSION.major < 2:

View File

Before

Width:  |  Height:  |  Size: 83 B

After

Width:  |  Height:  |  Size: 83 B

View File

@@ -12,6 +12,7 @@
# limitations under the License.
import io
import json
import sys
import os
@@ -23,13 +24,12 @@ import pyarrow as pa
import pydantic
import pytest
import pytz
from lance.arrow import ImageURIArray
from pydantic import Field
from lance.arrow import EncodedImageType
from lancedb.pydantic import (
PYDANTIC_VERSION,
Image,
ImageURI,
EncodedImage,
LanceModel,
Vector,
pydantic_to_schema,
@@ -257,32 +257,25 @@ def test_lance_model():
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
def test_lance_model_with_lance_types():
png_uris = [
"file://" + os.path.join(os.path.dirname(__file__), "images/1.png"),
]
if os.name == "nt":
png_uris = [str(Path(x)) for x in png_uris]
default_image_uris = ImageURIArray.from_uris(png_uris)
default_encoded_images = default_image_uris.read_uris()
def test_schema_with_images():
pytest.importorskip("PIL")
import PIL.Image
class TestModel(LanceModel):
encoded_images: Image() = Field(default=default_encoded_images)
image_uris: ImageURI() = Field(default=default_image_uris)
img: EncodedImage()
schema = pydantic_to_schema(TestModel)
schema = pa.schema([pa.field("img", EncodedImageType(), False)])
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["encoded_images", "image_uris"]
assert TestModel.field_names() == ["img"]
expected_model = TestModel()
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
with open(img_path, "rb") as f:
img_bytes = f.read()
actual_model = TestModel(
encoded_images=default_encoded_images, image_uris=default_image_uris
)
assert expected_model == actual_model
m1 = TestModel(img=PIL.Image.open(img_path))
m2 = TestModel(img=img_bytes)
actual_model = TestModel(
encoded_images=default_encoded_images, image_uris=default_image_uris
)
assert expected_model == actual_model
def tobytes(m):
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
assert tobytes(m1) == tobytes(m2)

View File

@@ -14,6 +14,8 @@
import functools
from copy import copy
from datetime import date, datetime, timedelta
import io
import os
from pathlib import Path
from time import sleep
from typing import List
@@ -30,7 +32,7 @@ import pytest_asyncio
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import AsyncConnection, LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import EncodedImage, LanceModel, Vector
from lancedb.table import LanceTable
from pydantic import BaseModel
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
@@ -110,7 +112,6 @@ def test_create_table(db):
pa.field("item", pa.string()),
pa.field("price", pa.float32()),
pa.field("encoded_image", EncodedImageType()),
pa.field("image_uris", ImageURIType()),
]
)
expected = pa.Table.from_arrays(
@@ -121,9 +122,6 @@ def test_create_table(db):
pa.ExtensionArray.from_storage(
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
),
pa.ExtensionArray.from_storage(
ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string())
),
],
schema=schema,
)
@@ -134,14 +132,12 @@ def test_create_table(db):
"item": "foo",
"price": 10.0,
"encoded_image": b"foo",
"image_uris": "/tmp/foo",
},
{
"vector": [5.9, 26.5],
"item": "bar",
"price": 20.0,
"encoded_image": b"bar",
"image_uris": "/tmp/bar",
},
]
]
@@ -1046,3 +1042,22 @@ async def test_time_travel(db_async: AsyncConnection):
# Can't use restore if not checked out
with pytest.raises(ValueError, match="checkout before running restore"):
await table.restore()
def test_add_image(tmp_path):
    """Add a model holding a PIL image to a table created from its schema.

    Skipped when Pillow is not installed.
    """
    pytest.importorskip("PIL")
    import PIL.Image

    db = lancedb.connect(tmp_path)

    class TestModel(LanceModel):
        img: EncodedImage()

    img_path = Path(os.path.dirname(__file__)) / "images/1.png"
    m1 = TestModel(img=PIL.Image.open(img_path))

    table = LanceTable.create(db, "my_table", schema=TestModel)
    table.add([m1])
    # Without a check the test only proves that ``add`` did not raise.
    assert len(table) == 1