From ac955a5a7efcc73f2162a2526e0c0e896eb16eb2 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Wed, 25 Oct 2023 04:36:57 +0200 Subject: [PATCH] initial commit --- python/python/lancedb/pydantic.py | 123 +++++++++++++++++++++++++++ python/python/tests/test_pydantic.py | 24 +++++- python/python/tests/test_table.py | 25 +++++- 3 files changed, 169 insertions(+), 3 deletions(-) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 89bf08d6..58b6294f 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -186,6 +186,129 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType: ) +class ImageMixin(ABC): + @staticmethod + @abstractmethod + def value_arrow_type() -> pa.DataType: + raise NotImplementedError + + +def EncodedImage() -> Type[ImageMixin]: + """Pydantic EncodedImage type: bytes stored as Arrow binary. + + !!! warning + Experimental feature. + + Examples + -------- + + >>> import pydantic + >>> from lancedb.pydantic import EncodedImage + ... + >>> class MyModel(pydantic.BaseModel): + ... image: EncodedImage() + >>> schema = pydantic_to_schema(MyModel) + >>> assert schema == pa.schema([ + ... pa.field("image", pa.binary(), False) + ... 
]) + """ + + class EncodedImage(bytes, ImageMixin): + def __repr__(self): + return "EncodedImage()" + + @staticmethod + def value_arrow_type() -> pa.DataType: + return pa.binary() + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler + ) -> CoreSchema: + return core_schema.no_info_after_validator_function( + cls, + core_schema.bytes_schema(), + ) + + @classmethod + def __get_validators__(cls) -> Generator[Callable, None, None]: + yield cls.validate + + # For pydantic v1 + @classmethod + def validate(cls, v): + if not isinstance(v, bytes): + raise TypeError("A bytes is needed") + return cls(v) + + if PYDANTIC_VERSION < (2, 0): + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]): + field_schema["type"] = "string" + field_schema["format"] = "binary" + + return EncodedImage + + +def ImageURI() -> Type[ImageMixin]: + """Pydantic ImageURI type: str stored as Arrow string. + + !!! warning + Experimental feature. + + Examples + -------- + + >>> import pydantic + >>> from lancedb.pydantic import ImageURI + ... + >>> class MyModel(pydantic.BaseModel): + ... url: ImageURI() + >>> schema = pydantic_to_schema(MyModel) + >>> assert schema == pa.schema([ + ... pa.field("url", pa.utf8(), False), + ... 
]) + """ + + class ImageURI(str, ImageMixin): + def __repr__(self): + return "ImageURI()" + + @staticmethod + def value_arrow_type() -> pa.DataType: + return pa.string() + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler + ) -> CoreSchema: + return core_schema.no_info_after_validator_function( + cls, + core_schema.str_schema(), + ) + + @classmethod + def __get_validators__(cls) -> Generator[Callable, None, None]: + yield cls.validate + + # For pydantic v1 + @classmethod + def validate(cls, v): + if not isinstance(v, str): + raise TypeError("A str is needed") + return cls(v) + + if PYDANTIC_VERSION < (2, 0): + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]): + field_schema["type"] = "string" + field_schema["format"] = "string" + + return ImageURI + + if PYDANTIC_VERSION.major < 2: def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 8f9d335c..97c8a0aa 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -20,7 +20,14 @@ from typing import List, Optional, Tuple import pyarrow as pa import pydantic import pytest -from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema +from lancedb.pydantic import ( + PYDANTIC_VERSION, + LanceModel, + Vector, + pydantic_to_schema, + EncodedImage, + ImageURI, +) from pydantic import Field @@ -243,3 +250,18 @@ def test_lance_model(): t = TestModel() assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3]) + + +def test_lance_model_with_lance_types(): + class TestModel(LanceModel): + image: EncodedImage() = Field() + uri: ImageURI() = Field() + # TODO: tensor type? 
+ + # TODO + # schema = pydantic_to_schema(TestModel) + # assert schema == TestModel.to_arrow_schema() + # assert TestModel.field_names() == ["image", "uri"] + # + # t = TestModel() + # assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3], image=EncodedImageArray(), uri="https://lancedb.dev") diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 518b19e1..be6b4041 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -33,6 +33,7 @@ from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistr from lancedb.pydantic import LanceModel, Vector from lancedb.table import LanceTable from pydantic import BaseModel +from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType class MockDB: @@ -108,6 +109,8 @@ def test_create_table(db): pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("item", pa.string()), pa.field("price", pa.float32()), + pa.field("encoded_image", EncodedImageType()), + pa.field("image_uris", ImageURIType()), ] ) expected = pa.Table.from_arrays( @@ -115,13 +118,31 @@ def test_create_table(db): pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2), pa.array(["foo", "bar"]), pa.array([10.0, 20.0]), + pa.ExtensionArray.from_storage( + EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary()) + ), + pa.ExtensionArray.from_storage( + ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string()) + ), ], schema=schema, ) data = [ [ - {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, - {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + { + "vector": [3.1, 4.1], + "item": "foo", + "price": 10.0, + "encoded_image": b"foo", + "image_uris": "/tmp/foo", + }, + { + "vector": [5.9, 26.5], + "item": "bar", + "price": 20.0, + "encoded_image": b"bar", + "image_uris": "/tmp/bar", + }, ] ] df = pd.DataFrame(data[0])