Compare commits

...

17 Commits

Author SHA1 Message Date
Chang She
86c9bc0d2d stuff 2024-03-12 19:03:26 -07:00
Chang She
c1dfad675a see if we make EncodedImage work 2024-03-12 19:03:26 -07:00
Chang She
2e1838a62a ruff 2024-03-12 19:03:26 -07:00
Chang She
4d39f63cf6 add import guidance 2024-03-12 19:03:26 -07:00
Chang She
3c4f2a7020 fix 2024-03-12 19:03:26 -07:00
Chang She
48a4202748 fix 2024-03-12 19:03:26 -07:00
Chang She
2084fbcff4 working? 2024-03-12 19:03:26 -07:00
Chang She
408988abce just keep EncodedImage for now 2024-03-12 19:03:25 -07:00
Chang She
e68fbf65cc foo 2024-03-12 18:45:32 -07:00
Rok Mihevc
63399dc0ee unused imports 2024-03-12 18:45:32 -07:00
Rok Mihevc
0b0f4e9d1c __get_pydantic_core_schema__ 2024-03-12 18:45:32 -07:00
Rok Mihevc
2ec0e79303 Minor change 2024-03-12 18:45:32 -07:00
Rok Mihevc
d86dd2c60d test automatic reading of uris 2024-03-12 18:45:32 -07:00
Rok Mihevc
67b38d6115 changes 2024-03-12 18:45:32 -07:00
Rok Mihevc
c112dea28b work 2024-03-12 18:45:32 -07:00
Rok Mihevc
d662b9744e black 2024-03-12 18:45:32 -07:00
Rok Mihevc
ac955a5a7e initial commit 2024-03-12 18:45:32 -07:00
7 changed files with 268 additions and 46 deletions

View File

@@ -57,6 +57,7 @@ tests = [
"duckdb", "duckdb",
"pytz", "pytz",
"polars>=0.19", "polars>=0.19",
"pillow",
] ]
dev = ["ruff", "pre-commit"] dev = ["ruff", "pre-commit"]
docs = [ docs = [

View File

@@ -126,6 +126,10 @@ class OpenClipEmbeddings(EmbeddingFunction):
""" """
Issue concurrent requests to retrieve the image data Issue concurrent requests to retrieve the image data
""" """
return [
self.generate_image_embedding(image) for image in tqdm(images)
]
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [ futures = [
executor.submit(self.generate_image_embedding, image) executor.submit(self.generate_image_embedding, image)

View File

@@ -16,6 +16,7 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
import io
import sys import sys
import types import types
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -26,7 +27,9 @@ from typing import (
Callable, Callable,
Dict, Dict,
Generator, Generator,
Iterable,
List, List,
Tuple,
Type, Type,
Union, Union,
_GenericAlias, _GenericAlias,
@@ -36,19 +39,30 @@ import numpy as np
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import semver import semver
from lance.arrow import (
EncodedImageType,
)
from lance.util import _check_huggingface
from pydantic.fields import FieldInfo
from pydantic_core import core_schema
from .util import attempt_import_or_raise
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
if TYPE_CHECKING: if TYPE_CHECKING:
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from .embeddings import EmbeddingFunctionConfig from .embeddings import EmbeddingFunctionConfig
try:
from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
class FixedSizeListMixin(ABC): class FixedSizeListMixin(ABC):
@staticmethod @staticmethod
@@ -123,7 +137,7 @@ def Vector(
@classmethod @classmethod
def __get_pydantic_core_schema__( def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema: ) -> "CoreSchema":
return core_schema.no_info_after_validator_function( return core_schema.no_info_after_validator_function(
cls, cls,
core_schema.list_schema( core_schema.list_schema(
@@ -181,25 +195,118 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
elif getattr(py_type, "__origin__", None) in (list, tuple): elif getattr(py_type, "__origin__", None) in (list, tuple):
child = py_type.__args__[0] child = py_type.__args__[0]
return pa.list_(_py_type_to_arrow_type(child, field)) return pa.list_(_py_type_to_arrow_type(child, field))
elif _safe_is_huggingface_image():
import datasets
raise TypeError( raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}." f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
) )
class ImageMixin(ABC):
    """Abstract base for Pydantic field types that represent image data.

    Concrete subclasses report the Arrow type their values are stored as by
    implementing :meth:`value_arrow_type`.
    """

    @staticmethod
    @abstractmethod
    def value_arrow_type() -> "pa.DataType":
        """Return the Arrow data type used to store values of this type."""
        raise NotImplementedError
def EncodedImage():
    """Create the Pydantic field type for images stored inline as encoded bytes.

    The class is built inside this factory so that Pillow is only imported
    when an image field is actually declared.

    Returns
    -------
    type
        A ``bytes`` subclass implementing ``ImageMixin`` that accepts either
        raw ``bytes`` or a ``PIL.Image.Image`` as input.
    """
    # NOTE(review): presumably raises when Pillow is not installed, using the
    # second argument as the install hint -- confirm against the helper.
    attempt_import_or_raise("PIL", "pillow or pip install lancedb[embeddings]")
    import PIL.Image

    class EncodedImage(bytes, ImageMixin):
        """Pydantic type for inlined images.

        !!! warning
            Experimental feature.

        Examples
        --------

        >>> import pydantic
        >>> from lance.arrow import EncodedImageType
        >>> from lancedb.pydantic import EncodedImage
        ...
        >>> class MyModel(pydantic.BaseModel):
        ...     image: EncodedImage()
        >>> schema = pydantic_to_schema(MyModel)
        >>> assert schema == pa.schema([
        ...     pa.field("image", EncodedImageType(), False)
        ... ])
        """

        def __repr__(self):
            return "EncodedImage()"

        @staticmethod
        def value_arrow_type() -> pa.DataType:
            """Return the Arrow extension type used for encoded images."""
            return EncodedImageType()

        # Pydantic v2 hook: core (validation + serialization) schema.
        @classmethod
        def __get_pydantic_core_schema__(
            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
        ) -> "CoreSchema":
            from_bytes_schema = core_schema.bytes_schema()

            return core_schema.json_or_python_schema(
                json_schema=from_bytes_schema,
                # Python input may be a PIL image or raw bytes.
                python_schema=core_schema.union_schema(
                    [
                        core_schema.is_instance_schema(PIL.Image.Image),
                        from_bytes_schema,
                    ]
                ),
                # Serialization goes through ``validate`` so a PIL image
                # is dumped as its encoded bytes.
                serialization=core_schema.plain_serializer_function_ser_schema(
                    lambda instance: cls.validate(instance)
                ),
            )

        # Pydantic v2 hook: JSON schema is that of a bytes field.
        @classmethod
        def __get_pydantic_json_schema__(
            cls, _core_schema: "CoreSchema", handler: "GetJsonSchemaHandler"
        ) -> "JsonSchemaValue":
            return handler(core_schema.bytes_schema())

        # Pydantic v1 hook: validators are discovered through this generator.
        @classmethod
        def __get_validators__(cls) -> Generator[Callable, None, None]:
            yield cls.validate

        @classmethod
        def validate(cls, v):
            """Convert ``v`` to encoded image bytes.

            Parameters
            ----------
            v : bytes or PIL.Image.Image
                Bytes are returned unchanged; a PIL image is encoded into
                an in-memory buffer.

            Raises
            ------
            TypeError
                If ``v`` is neither ``bytes`` nor a ``PIL.Image.Image``.
            """
            if isinstance(v, bytes):
                return v
            if isinstance(v, PIL.Image.Image):
                with io.BytesIO() as output:
                    # ``Image.format`` is None for images created in memory
                    # (not loaded from a file). Saving to a file object needs
                    # an explicit format, so fall back to lossless PNG.
                    v.save(output, format=v.format or "PNG")
                    return output.getvalue()
            raise TypeError(
                "EncodedImage can take bytes or PIL.Image.Image "
                f"as input but got {type(v)}"
            )

        if PYDANTIC_VERSION < (2, 0):

            # Pydantic v1 hook: describe the field as a binary string.
            @classmethod
            def __modify_schema__(cls, field_schema: Dict[str, Any]):
                field_schema["type"] = "string"
                field_schema["format"] = "binary"

    return EncodedImage
if PYDANTIC_VERSION.major < 2: if PYDANTIC_VERSION.major < 2:
def _safe_get_fields(model: pydantic.BaseModel):
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: return model.__fields__
return [
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
]
else: else:
def _safe_get_fields(model: pydantic.BaseModel):
return model.model_fields
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [ def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
_pydantic_to_field(name, field) return [
for name, field in model.model_fields.items() _pydantic_to_field(name, field)
] for name, field in _safe_get_fields(model).items()
]
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
@@ -230,6 +337,9 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
return pa.struct(fields) return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin): elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim()) return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
elif issubclass(field.annotation, ImageMixin):
return field.annotation.value_arrow_type()
return _py_type_to_arrow_type(field.annotation, field) return _py_type_to_arrow_type(field.annotation, field)
@@ -335,13 +445,7 @@ class LanceModel(pydantic.BaseModel):
""" """
Get the field names of this model. Get the field names of this model.
""" """
return list(cls.safe_get_fields().keys()) return list(_safe_get_fields(cls).keys())
@classmethod
def safe_get_fields(cls):
if PYDANTIC_VERSION.major < 2:
return cls.__fields__
return cls.model_fields
@classmethod @classmethod
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]: def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
@@ -351,14 +455,16 @@ class LanceModel(pydantic.BaseModel):
from .embeddings import EmbeddingFunctionConfig from .embeddings import EmbeddingFunctionConfig
vec_and_function = [] vec_and_function = []
for name, field_info in cls.safe_get_fields().items(): def get_vector_column(name, field_info):
func = get_extras(field_info, "vector_column_for") fun = get_extras(field_info, "vector_column_for")
if func is not None: if func is not None:
vec_and_function.append([name, func]) vec_and_function.append([name, func])
visit_fields(_safe_get_fields(cls).items(), get_vector_column)
configs = [] configs = []
# find the source columns for each one
for vec, func in vec_and_function: for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items(): def get_source_column(source, field_info):
src_func = get_extras(field_info, "source_column_for") src_func = get_extras(field_info, "source_column_for")
if src_func is func: if src_func is func:
# note we can't use == here since the function is a pydantic # note we can't use == here since the function is a pydantic
@@ -371,20 +477,48 @@ class LanceModel(pydantic.BaseModel):
source_column=source, vector_column=vec, function=func source_column=source, vector_column=vec, function=func
) )
) )
visit_fields(_safe_get_fields(cls).items(), get_source_column)
return configs return configs
def visit_fields(
    fields: Iterable[Tuple[str, "FieldInfo"]],
    visitor: Callable[[str, "FieldInfo"], Any],
):
    """
    Visit all the leaf fields in a Pydantic model.

    Fields whose annotation is a ``pydantic.BaseModel`` subclass are not
    passed to ``visitor`` themselves; their sub-fields are visited
    recursively, with names prefixed by the parent name and a dot.

    Parameters
    ----------
    fields : Iterable[Tuple(str, FieldInfo)]
        The fields to visit.
    visitor : Callable[[str, FieldInfo], Any]
        The visitor function.
    """
    for name, field_info in fields:
        # A field object without an ``annotation`` attribute is treated as a
        # leaf instead of raising AttributeError.
        annotation = getattr(field_info, "annotation", None)
        # ``issubclass`` requires a class, so generic aliases such as
        # ``List[int]`` must be filtered out first.
        is_nested_model = isinstance(annotation, type) and issubclass(
            annotation, pydantic.BaseModel
        )
        if is_nested_model:
            visit_fields(
                _safe_get_fields(annotation).items(),
                _add_prefix(visitor, name),
            )
        else:
            visitor(name, field_info)
def get_extras(field_info: FieldInfo, key: str) -> Any: def _add_prefix(visitor: Callable[[str, FieldInfo], Any], prefix: str) -> Callable[[str, FieldInfo], Any]:
""" def prefixed_visitor(name: str, field: FieldInfo):
Get the extra metadata from a Pydantic FieldInfo. return visitor(f"{prefix}.{name}", field)
""" return prefixed_visitor
if PYDANTIC_VERSION.major >= 2:
return (field_info.json_schema_extra or {}).get(key)
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
if PYDANTIC_VERSION.major < 2: if PYDANTIC_VERSION.major < 2:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]: def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
""" """
Convert a Pydantic model to a dictionary. Convert a Pydantic model to a dictionary.
@@ -393,6 +527,13 @@ if PYDANTIC_VERSION.major < 2:
else: else:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
return (field_info.json_schema_extra or {}).get(key)
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]: def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
""" """
Convert a Pydantic model to a dictionary. Convert a Pydantic model to a dictionary.

View File

@@ -37,6 +37,7 @@ import pyarrow as pa
import pyarrow.compute as pc import pyarrow.compute as pc
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
from lance import LanceDataset from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from lance.vector import vec_to_table from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -74,7 +75,16 @@ def _sanitize_data(
on_bad_vectors: str, on_bad_vectors: str,
fill_value: Any, fill_value: Any,
): ):
if isinstance(data, list): import pdb; pdb.set_trace()
if _check_for_hugging_face(data):
# Huggingface datasets
import datasets
if isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
data = data.data.to_batches()
elif isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels # convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel): if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema() schema = data[0].__class__.to_arrow_schema()
@@ -136,11 +146,10 @@ def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value data: Iterable, schema, metadata, on_bad_vectors, fill_value
): ):
for batch in data: for batch in data:
if not isinstance(batch, pa.RecordBatch): if isinstance(batch, pa.RecordBatch):
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value) batch = pa.Table.from_batches([batch])
for batch in table.to_batches(): table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
yield batch for batch in table.to_batches():
else:
yield batch yield batch

Binary file not shown.

After

Width:  |  Height:  |  Size: 83 B

View File

@@ -12,17 +12,27 @@
# limitations under the License. # limitations under the License.
import io
import json import json
import os
import sys import sys
from datetime import date, datetime from datetime import date, datetime
from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import pytest import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from pydantic import Field from pydantic import Field
from lancedb.pydantic import (
PYDANTIC_VERSION,
EncodedImage,
LanceModel,
Vector,
pydantic_to_schema,
)
@pytest.mark.skipif( @pytest.mark.skipif(
sys.version_info < (3, 9), sys.version_info < (3, 9),
@@ -243,3 +253,23 @@ def test_lance_model():
t = TestModel() t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3]) assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
def test_schema_with_images():
    """An image field built from a PIL image or raw bytes dumps to the same pixels."""
    pytest.importorskip("PIL")
    import PIL.Image

    class TestModel(LanceModel):
        img: EncodedImage()

    def decoded_pixels(model):
        # Compare decoded pixel data; presumably re-encoding is not
        # guaranteed to reproduce the source file byte-for-byte.
        with PIL.Image.open(io.BytesIO(model.model_dump()["img"])) as decoded:
            return decoded.tobytes()

    img_path = Path(os.path.dirname(__file__)) / "images/1.png"
    with open(img_path, "rb") as f:
        img_bytes = f.read()

    # PIL opens files lazily, so the model must be dumped while the source
    # image is still open; the ``with`` block then releases the file handle.
    with PIL.Image.open(img_path) as pil_img:
        pixels_from_image = decoded_pixels(TestModel(img=pil_img))

    pixels_from_bytes = decoded_pixels(TestModel(img=img_bytes))

    assert pixels_from_image == pixels_from_bytes

View File

@@ -12,6 +12,8 @@
# limitations under the License. # limitations under the License.
import functools import functools
import io
import os
from copy import copy from copy import copy
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from pathlib import Path from pathlib import Path
@@ -20,19 +22,21 @@ from typing import List
from unittest.mock import PropertyMock, patch from unittest.mock import PropertyMock, patch
import lance import lance
import lancedb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import polars as pl import polars as pl
import pyarrow as pa import pyarrow as pa
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from lance.arrow import EncodedImageType
from pydantic import BaseModel
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import AsyncConnection, LanceDBConnection from lancedb.db import AsyncConnection, LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import EncodedImage, LanceModel, Vector
from lancedb.table import LanceTable from lancedb.table import LanceTable
from pydantic import BaseModel
class MockDB: class MockDB:
@@ -108,6 +112,7 @@ def test_create_table(db):
pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()), pa.field("item", pa.string()),
pa.field("price", pa.float32()), pa.field("price", pa.float32()),
pa.field("encoded_image", EncodedImageType()),
] ]
) )
expected = pa.Table.from_arrays( expected = pa.Table.from_arrays(
@@ -115,13 +120,26 @@ def test_create_table(db):
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2), pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
pa.array(["foo", "bar"]), pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]), pa.array([10.0, 20.0]),
pa.ExtensionArray.from_storage(
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
),
], ],
schema=schema, schema=schema,
) )
data = [ data = [
[ [
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, "vector": [3.1, 4.1],
"item": "foo",
"price": 10.0,
"encoded_image": b"foo",
},
{
"vector": [5.9, 26.5],
"item": "bar",
"price": 20.0,
"encoded_image": b"bar",
},
] ]
] ]
df = pd.DataFrame(data[0]) df = pd.DataFrame(data[0])
@@ -1025,3 +1043,22 @@ async def test_time_travel(db_async: AsyncConnection):
# Can't use restore if not checked out # Can't use restore if not checked out
with pytest.raises(ValueError, match="checkout before running restore"): with pytest.raises(ValueError, match="checkout before running restore"):
await table.restore() await table.restore()
def test_add_image(tmp_path):
    """A model holding a PIL image can be added to a table built from its schema."""
    pytest.importorskip("PIL")
    import PIL.Image

    class TestModel(LanceModel):
        img: EncodedImage()

    db = lancedb.connect(tmp_path)
    table = LanceTable.create(db, "my_table", schema=TestModel)

    img_path = Path(os.path.dirname(__file__)) / "images/1.png"
    # Keep the image open while it is added: PIL loads pixel data lazily and
    # the model is only encoded to bytes when it is written.
    with PIL.Image.open(img_path) as pil_img:
        table.add([TestModel(img=pil_img)])

    # The original test made no assertion; check the row actually landed.
    assert len(table) == 1