mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-12 23:02:59 +00:00
initial commit
This commit is contained in:
@@ -186,6 +186,129 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
)
|
||||
|
||||
|
||||
class ImageMixin(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def EncodedImage() -> Type[ImageMixin]:
|
||||
"""Pydantic EncodedImage Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import EncodedImage
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... image: EncodedImage()
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("image", pa.binary(), False)
|
||||
... ])
|
||||
"""
|
||||
|
||||
class EncodedImage(bytes, ImageMixin):
|
||||
def __repr__(self):
|
||||
return "EncodedImage()"
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return pa.binary()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.binary_schema(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, bytes):
|
||||
raise TypeError("A bytes is needed")
|
||||
return cls(v)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["type"] = "string"
|
||||
field_schema["format"] = "binary"
|
||||
|
||||
return EncodedImage
|
||||
|
||||
|
||||
def ImageURI() -> Type[ImageMixin]:
|
||||
"""Pydantic ImageUri Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import ImageURI
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... url: ImageURI()
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("url", pa.utf8(), False),
|
||||
... ])
|
||||
"""
|
||||
|
||||
class ImageURI(str, ImageMixin):
|
||||
def __repr__(self):
|
||||
return "ImageURI()"
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return pa.string()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.string_schema(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, str):
|
||||
raise TypeError("A str is needed")
|
||||
return cls(v)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["type"] = "string"
|
||||
field_schema["format"] = "string"
|
||||
|
||||
return ImageURI
|
||||
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
|
||||
@@ -20,7 +20,14 @@ from typing import List, Optional, Tuple
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
LanceModel,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
EncodedImage,
|
||||
ImageURI,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@@ -243,3 +250,18 @@ def test_lance_model():
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
|
||||
def test_lance_model_with_lance_types():
|
||||
class TestModel(LanceModel):
|
||||
image: EncodedImage() = Field()
|
||||
uri: ImageURI() = Field()
|
||||
# TODO: tensor type?
|
||||
|
||||
# TODO
|
||||
# schema = pydantic_to_schema(TestModel)
|
||||
# assert schema == TestModel.to_arrow_schema()
|
||||
# assert TestModel.field_names() == ["image", "uri"]
|
||||
#
|
||||
# t = TestModel()
|
||||
# assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3], image=EncodedImageArray(), uri="https://lancedb.dev")
|
||||
|
||||
@@ -33,6 +33,7 @@ from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistr
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
from pydantic import BaseModel
|
||||
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
|
||||
|
||||
|
||||
class MockDB:
|
||||
@@ -108,6 +109,8 @@ def test_create_table(db):
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float32()),
|
||||
pa.field("encoded_image", EncodedImageType()),
|
||||
pa.field("image_uris", ImageURIType()),
|
||||
]
|
||||
)
|
||||
expected = pa.Table.from_arrays(
|
||||
@@ -115,13 +118,31 @@ def test_create_table(db):
|
||||
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
pa.ExtensionArray.from_storage(
|
||||
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
|
||||
),
|
||||
pa.ExtensionArray.from_storage(
|
||||
ImageURIType(), pa.array(["/tmp/foo", "/tmp/bar"], pa.string())
|
||||
),
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
data = [
|
||||
[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
{
|
||||
"vector": [3.1, 4.1],
|
||||
"item": "foo",
|
||||
"price": 10.0,
|
||||
"encoded_image": b"foo",
|
||||
"image_uris": "/tmp/foo",
|
||||
},
|
||||
{
|
||||
"vector": [5.9, 26.5],
|
||||
"item": "bar",
|
||||
"price": 20.0,
|
||||
"encoded_image": b"bar",
|
||||
"image_uris": "/tmp/bar",
|
||||
},
|
||||
]
|
||||
]
|
||||
df = pd.DataFrame(data[0])
|
||||
|
||||
Reference in New Issue
Block a user