diff --git a/python/pyproject.toml b/python/pyproject.toml index 45508911..dec1edc1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -57,6 +57,7 @@ tests = [ "duckdb", "pytz", "polars>=0.19", + "PIL" ] dev = ["ruff", "pre-commit"] docs = [ diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index c2db4481..8233d930 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -16,6 +16,7 @@ from __future__ import annotations import inspect +import io import sys import types from abc import ABC, abstractmethod @@ -38,11 +39,12 @@ import pydantic import semver from pydantic.fields import FieldInfo from lance.arrow import ( - EncodedImageScalar, - ImageURIScalar, - ImageURIArray, - EncodedImageType, EncodedImageArray, + EncodedImageScalar, + EncodedImageType, + ImageURIArray, + ImageURIScalar, + ImageURIType, ) PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) @@ -203,165 +205,87 @@ class ImageMixin(ABC): raise NotImplementedError -class Image(bytes, ImageMixin): - """Pydantic type for inlined images. +def EncodedImage(): + import PIL.Image - !!! warning - Experimental feature. + class EncodedImage(bytes, ImageMixin): + """Pydantic type for inlined images. - Examples - -------- + !!! warning + Experimental feature. - >>> import pydantic - >>> from lancedb.pydantic import Image - ... - >>> class MyModel(pydantic.BaseModel): - ... image: Image - >>> schema = pydantic_to_schema(MyModel) - >>> assert schema == pa.schema([ - ... pa.field("image", pa.binary(), False) - ... ]) - """ + Examples + -------- - def __repr__(self): - return "Image()" + >>> import pydantic + >>> from lancedb.pydantic import EncodedImage + ... + >>> class MyModel(pydantic.BaseModel): + ... image: EncodedImage() + >>> schema = pydantic_to_schema(MyModel) + >>> assert schema == pa.schema([ + ... pa.field("image", pa.binary(), False) + ... ]) + """ - @staticmethod - def value_arrow_type() -> pa.DataType: - return pa.binary() + def __repr__(self): + return "EncodedImage()" - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler - ) -> CoreSchema: - def validate_from_bytes(value: bytes) -> EncodedImageScalar: - return EncodedImageScalar(value) - - from_bytes_schema = core_schema.chain_schema( - [ - core_schema.bytes_schema(), - core_schema.no_info_plain_validator_function(validate_from_bytes), - ] - ) - - return core_schema.json_or_python_schema( - json_schema=from_bytes_schema, - python_schema=core_schema.union_schema( - [ - core_schema.is_instance_schema(EncodedImageArray), - from_bytes_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: instance.values - ), - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - return handler(core_schema.bytes_schema()) - - @classmethod - def __get_validators__(cls) -> Generator[Callable, None, None]: - yield cls.validate - - # For pydantic v2 - @classmethod - def validate(cls, v): - if isinstance(v, ImageURIArray): - v = v.read_uris() - if isinstance(v, pa.BinaryArray): - v = pa.ExtensionArray.from_storage(EncodedImageType(), v) - if not isinstance(v, EncodedImageArray): - raise TypeError("Invalid input array type", type(v)) - - return v - - if PYDANTIC_VERSION < (2, 0): + @staticmethod + def value_arrow_type() -> pa.DataType: + return EncodedImageType() @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]): - field_schema["type"] = "string" - field_schema["format"] = "binary" + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler + ) -> CoreSchema: + from_bytes_schema = core_schema.bytes_schema() - -class ImageURI(str, ImageMixin): - """Pydantic ImageUri Type. - - !!! warning - Experimental feature. - - Examples - -------- - - >>> import pydantic - >>> from lancedb.pydantic import ImageURI - ... - >>> class MyModel(pydantic.BaseModel): - ... url: ImageURI() - >>> schema = pydantic_to_schema(MyModel) - >>> assert schema == pa.schema([ - ... pa.field("url", pa.utf8(), False), - ... ]) - """ - def __repr__(self): - return "ImageURI()" - - @staticmethod - def value_arrow_type() -> pa.DataType: - return pa.string() - - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler - ) -> CoreSchema: - def validate_from_str(value: str) -> ImageURIScalar: - return ImageURIScalar(value) - - from_str_schema = core_schema.chain_schema( - [ - core_schema.str_schema(), - core_schema.no_info_plain_validator_function(validate_from_str), - ] - ) - - return core_schema.json_or_python_schema( - json_schema=from_str_schema, - python_schema=core_schema.union_schema( - [ - core_schema.is_instance_schema(ImageURIArray), - from_str_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: instance.values - ), - ) - - @classmethod - def __get_validators__(cls) -> Generator[Callable, None, None]: - yield cls.validate - - # For pydantic v2 - @classmethod - def validate(cls, v): - from lance.arrow import ImageURIArray, ImageURIType - - if isinstance(v, (str, pa.StringArray)): - v = pa.ExtensionArray.from_storage(ImageURIType(), v) - if not isinstance(v, ImageURIArray): - raise TypeError("Invalid input array type", type(v)) - - return v - - if PYDANTIC_VERSION < (2, 0): + return core_schema.json_or_python_schema( + json_schema=from_bytes_schema, + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(PIL.Image.Image), + from_bytes_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: cls.validate(instance) + ), + ) @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]): - field_schema["type"] = "string" - field_schema["format"] = "string" + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.bytes_schema()) + + @classmethod + def __get_validators__(cls) -> Generator[Callable, None, None]: + yield cls.validate + + # For pydantic v2 + @classmethod + def validate(cls, v): + if isinstance(v, bytes): + return v + if isinstance(v, PIL.Image.Image): + with io.BytesIO() as output: + v.save(output, format=v.format) + return output.getvalue() + raise TypeError( + "EncodedImage can take bytes or PIL.Image.Image " + f"as input but got {type(v)}" + ) + + if PYDANTIC_VERSION < (2, 0): + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]): + field_schema["type"] = "string" + field_schema["format"] = "binary" + + return EncodedImage if PYDANTIC_VERSION.major < 2: diff --git a/python/tests/images/1.png b/python/python/tests/images/1.png similarity index 100% rename from python/tests/images/1.png rename to python/python/tests/images/1.png diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 16e6a678..2e458a47 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -12,6 +12,7 @@ # limitations under the License. +import io import json import sys import os @@ -23,13 +24,12 @@ import pyarrow as pa import pydantic import pytest import pytz -from lance.arrow import ImageURIArray from pydantic import Field +from lance.arrow import EncodedImageType from lancedb.pydantic import ( PYDANTIC_VERSION, - Image, - ImageURI, + EncodedImage, LanceModel, Vector, pydantic_to_schema, @@ -257,32 +257,25 @@ def test_lance_model(): assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3]) -def test_lance_model_with_lance_types(): - png_uris = [ - "file://" + os.path.join(os.path.dirname(__file__), "images/1.png"), - ] - if os.name == "nt": - png_uris = [str(Path(x)) for x in png_uris] - - default_image_uris = ImageURIArray.from_uris(png_uris) - default_encoded_images = default_image_uris.read_uris() +def test_schema_with_images(): + pytest.importorskip("PIL") + import PIL.Image class TestModel(LanceModel): - encoded_images: Image() = Field(default=default_encoded_images) - image_uris: ImageURI() = Field(default=default_image_uris) + img: EncodedImage() - schema = pydantic_to_schema(TestModel) + schema = pa.schema([pa.field("img", EncodedImageType(), False)]) assert schema == TestModel.to_arrow_schema() - assert TestModel.field_names() == ["encoded_images", "image_uris"] + assert TestModel.field_names() == ["img"] - expected_model = TestModel() + img_path = Path(os.path.dirname(__file__)) / "images/1.png" + with open(img_path, "rb") as f: + img_bytes = f.read() - actual_model = TestModel( - encoded_images=default_encoded_images, image_uris=default_image_uris - ) - assert expected_model == actual_model + m1 = TestModel(img=PIL.Image.open(img_path)) + m2 = TestModel(img=img_bytes) - actual_model = TestModel( - encoded_images=default_encoded_images, image_uris=default_image_uris - ) - assert expected_model == actual_model + def tobytes(m): + return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes() + + assert tobytes(m1) == tobytes(m2) diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index be6b4041..01134838 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -14,6 +14,8 @@ import functools from copy import copy from datetime import date, datetime, timedelta +import io +import os from pathlib import Path from time import sleep from typing import List @@ -30,7 +32,7 @@ import pytest_asyncio from lancedb.conftest import MockTextEmbeddingFunction from lancedb.db import AsyncConnection, LanceDBConnection from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry -from lancedb.pydantic import LanceModel, Vector +from lancedb.pydantic import EncodedImage, LanceModel, Vector from lancedb.table import LanceTable from pydantic import BaseModel from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType @@ -110,7 +112,6 @@ def test_create_table(db): pa.field("item", pa.string()), pa.field("price", pa.float32()), pa.field("encoded_image", EncodedImageType()), - pa.field("image_uris", ImageURIType()), ] ) expected = pa.Table.from_arrays( @@ -121,9 +122,6 @@ def test_create_table(db): pa.ExtensionArray.from_storage( EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary()) ), - pa.ExtensionArray.from_storage( - ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string()) - ), ], schema=schema, ) @@ -134,14 +132,12 @@ def test_create_table(db): "item": "foo", "price": 10.0, "encoded_image": b"foo", - "image_uris": "/tmp/foo", }, { "vector": [5.9, 26.5], "item": "bar", "price": 20.0, "encoded_image": b"bar", - "image_uris": "/tmp/bar", }, ] ] @@ -1046,3 +1042,22 @@ async def test_time_travel(db_async: AsyncConnection): # Can't use restore if not checked out with pytest.raises(ValueError, match="checkout before running restore"): await table.restore() + + +def test_add_image(tmp_path): + pytest.importorskip("PIL") + import PIL.Image + + db = lancedb.connect(tmp_path) + + class TestModel(LanceModel): + img: EncodedImage() + + img_path = Path(os.path.dirname(__file__)) / "images/1.png" + m1 = TestModel(img=PIL.Image.open(img_path)) + + def tobytes(m): + return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes() + + table = LanceTable.create(db, "my_table", schema=TestModel) + table.add([m1])