mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
just keep EncodedImage for now
This commit is contained in:
@@ -57,6 +57,7 @@ tests = [
|
||||
"duckdb",
|
||||
"pytz",
|
||||
"polars>=0.19",
|
||||
"PIL"
|
||||
]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = [
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import io
|
||||
import sys
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -38,11 +39,12 @@ import pydantic
|
||||
import semver
|
||||
from pydantic.fields import FieldInfo
|
||||
from lance.arrow import (
|
||||
EncodedImageScalar,
|
||||
ImageURIScalar,
|
||||
ImageURIArray,
|
||||
EncodedImageType,
|
||||
EncodedImageArray,
|
||||
EncodedImageScalar,
|
||||
EncodedImageType,
|
||||
ImageURIArray,
|
||||
ImageURIScalar,
|
||||
ImageURIType,
|
||||
)
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
@@ -203,165 +205,87 @@ class ImageMixin(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Image(bytes, ImageMixin):
|
||||
"""Pydantic type for inlined images.
|
||||
def EncodedImage():
|
||||
import PIL.Image
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
class EncodedImage(bytes, ImageMixin):
|
||||
"""Pydantic type for inlined images.
|
||||
|
||||
Examples
|
||||
--------
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import Image
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... image: Image
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("image", pa.binary(), False)
|
||||
... ])
|
||||
"""
|
||||
Examples
|
||||
--------
|
||||
|
||||
def __repr__(self):
|
||||
return "Image()"
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import EncodedImage
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... image: EncodedImage()
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("image", pa.binary(), False)
|
||||
... ])
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return pa.binary()
|
||||
def __repr__(self):
|
||||
return "EncodedImage()"
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
def validate_from_bytes(value: bytes) -> EncodedImageScalar:
|
||||
return EncodedImageScalar(value)
|
||||
|
||||
from_bytes_schema = core_schema.chain_schema(
|
||||
[
|
||||
core_schema.bytes_schema(),
|
||||
core_schema.no_info_plain_validator_function(validate_from_bytes),
|
||||
]
|
||||
)
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=from_bytes_schema,
|
||||
python_schema=core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(EncodedImageArray),
|
||||
from_bytes_schema,
|
||||
]
|
||||
),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda instance: instance.values
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
return handler(core_schema.bytes_schema())
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v2
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if isinstance(v, ImageURIArray):
|
||||
v = v.read_uris()
|
||||
if isinstance(v, pa.BinaryArray):
|
||||
v = pa.ExtensionArray.from_storage(EncodedImageType(), v)
|
||||
if not isinstance(v, EncodedImageArray):
|
||||
raise TypeError("Invalid input array type", type(v))
|
||||
|
||||
return v
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return EncodedImageType()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["type"] = "string"
|
||||
field_schema["format"] = "binary"
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
from_bytes_schema = core_schema.bytes_schema()
|
||||
|
||||
|
||||
class ImageURI(str, ImageMixin):
|
||||
"""Pydantic ImageUri Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import ImageURI
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... url: ImageURI()
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("url", pa.utf8(), False),
|
||||
... ])
|
||||
"""
|
||||
def __repr__(self):
|
||||
return "ImageURI()"
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return pa.string()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
def validate_from_str(value: str) -> ImageURIScalar:
|
||||
return ImageURIScalar(value)
|
||||
|
||||
from_str_schema = core_schema.chain_schema(
|
||||
[
|
||||
core_schema.str_schema(),
|
||||
core_schema.no_info_plain_validator_function(validate_from_str),
|
||||
]
|
||||
)
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=from_str_schema,
|
||||
python_schema=core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(ImageURIArray),
|
||||
from_str_schema,
|
||||
]
|
||||
),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda instance: instance.values
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v2
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
from lance.arrow import ImageURIArray, ImageURIType
|
||||
|
||||
if isinstance(v, (str, pa.StringArray)):
|
||||
v = pa.ExtensionArray.from_storage(ImageURIType(), v)
|
||||
if not isinstance(v, ImageURIArray):
|
||||
raise TypeError("Invalid input array type", type(v))
|
||||
|
||||
return v
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=from_bytes_schema,
|
||||
python_schema=core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(PIL.Image.Image),
|
||||
from_bytes_schema,
|
||||
]
|
||||
),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda instance: cls.validate(instance)
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["type"] = "string"
|
||||
field_schema["format"] = "string"
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
return handler(core_schema.bytes_schema())
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v2
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
if isinstance(v, PIL.Image.Image):
|
||||
with io.BytesIO() as output:
|
||||
v.save(output, format=v.format)
|
||||
return output.getvalue()
|
||||
raise TypeError(
|
||||
"EncodedImage can take bytes or PIL.Image.Image "
|
||||
f"as input but got {type(v)}"
|
||||
)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["type"] = "string"
|
||||
field_schema["format"] = "binary"
|
||||
|
||||
return EncodedImage
|
||||
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
|
Before Width: | Height: | Size: 83 B After Width: | Height: | Size: 83 B |
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
@@ -23,13 +24,12 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
import pytz
|
||||
from lance.arrow import ImageURIArray
|
||||
from pydantic import Field
|
||||
from lance.arrow import EncodedImageType
|
||||
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
Image,
|
||||
ImageURI,
|
||||
EncodedImage,
|
||||
LanceModel,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
@@ -257,32 +257,25 @@ def test_lance_model():
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
|
||||
def test_lance_model_with_lance_types():
|
||||
png_uris = [
|
||||
"file://" + os.path.join(os.path.dirname(__file__), "images/1.png"),
|
||||
]
|
||||
if os.name == "nt":
|
||||
png_uris = [str(Path(x)) for x in png_uris]
|
||||
|
||||
default_image_uris = ImageURIArray.from_uris(png_uris)
|
||||
default_encoded_images = default_image_uris.read_uris()
|
||||
def test_schema_with_images():
|
||||
pytest.importorskip("PIL")
|
||||
import PIL.Image
|
||||
|
||||
class TestModel(LanceModel):
|
||||
encoded_images: Image() = Field(default=default_encoded_images)
|
||||
image_uris: ImageURI() = Field(default=default_image_uris)
|
||||
img: EncodedImage()
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
schema = pa.schema([pa.field("img", EncodedImageType(), False)])
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["encoded_images", "image_uris"]
|
||||
assert TestModel.field_names() == ["img"]
|
||||
|
||||
expected_model = TestModel()
|
||||
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
|
||||
actual_model = TestModel(
|
||||
encoded_images=default_encoded_images, image_uris=default_image_uris
|
||||
)
|
||||
assert expected_model == actual_model
|
||||
m1 = TestModel(img=PIL.Image.open(img_path))
|
||||
m2 = TestModel(img=img_bytes)
|
||||
|
||||
actual_model = TestModel(
|
||||
encoded_images=default_encoded_images, image_uris=default_image_uris
|
||||
)
|
||||
assert expected_model == actual_model
|
||||
def tobytes(m):
|
||||
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
|
||||
|
||||
assert tobytes(m1) == tobytes(m2)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
import functools
|
||||
from copy import copy
|
||||
from datetime import date, datetime, timedelta
|
||||
import io
|
||||
import os
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import List
|
||||
@@ -30,7 +32,7 @@ import pytest_asyncio
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import AsyncConnection, LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.pydantic import EncodedImage, LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
from pydantic import BaseModel
|
||||
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
|
||||
@@ -110,7 +112,6 @@ def test_create_table(db):
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float32()),
|
||||
pa.field("encoded_image", EncodedImageType()),
|
||||
pa.field("image_uris", ImageURIType()),
|
||||
]
|
||||
)
|
||||
expected = pa.Table.from_arrays(
|
||||
@@ -121,9 +122,6 @@ def test_create_table(db):
|
||||
pa.ExtensionArray.from_storage(
|
||||
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
|
||||
),
|
||||
pa.ExtensionArray.from_storage(
|
||||
ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string())
|
||||
),
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
@@ -134,14 +132,12 @@ def test_create_table(db):
|
||||
"item": "foo",
|
||||
"price": 10.0,
|
||||
"encoded_image": b"foo",
|
||||
"image_uris": "/tmp/foo",
|
||||
},
|
||||
{
|
||||
"vector": [5.9, 26.5],
|
||||
"item": "bar",
|
||||
"price": 20.0,
|
||||
"encoded_image": b"bar",
|
||||
"image_uris": "/tmp/bar",
|
||||
},
|
||||
]
|
||||
]
|
||||
@@ -1046,3 +1042,22 @@ async def test_time_travel(db_async: AsyncConnection):
|
||||
# Can't use restore if not checked out
|
||||
with pytest.raises(ValueError, match="checkout before running restore"):
|
||||
await table.restore()
|
||||
|
||||
|
||||
def test_add_image(tmp_path):
|
||||
pytest.importorskip("PIL")
|
||||
import PIL.Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
class TestModel(LanceModel):
|
||||
img: EncodedImage()
|
||||
|
||||
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
|
||||
m1 = TestModel(img=PIL.Image.open(img_path))
|
||||
|
||||
def tobytes(m):
|
||||
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=TestModel)
|
||||
table.add([m1])
|
||||
|
||||
Reference in New Issue
Block a user