mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
17 Commits
python-v0.
...
600
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86c9bc0d2d | ||
|
|
c1dfad675a | ||
|
|
2e1838a62a | ||
|
|
4d39f63cf6 | ||
|
|
3c4f2a7020 | ||
|
|
48a4202748 | ||
|
|
2084fbcff4 | ||
|
|
408988abce | ||
|
|
e68fbf65cc | ||
|
|
63399dc0ee | ||
|
|
0b0f4e9d1c | ||
|
|
2ec0e79303 | ||
|
|
d86dd2c60d | ||
|
|
67b38d6115 | ||
|
|
c112dea28b | ||
|
|
d662b9744e | ||
|
|
ac955a5a7e |
@@ -57,6 +57,7 @@ tests = [
|
|||||||
"duckdb",
|
"duckdb",
|
||||||
"pytz",
|
"pytz",
|
||||||
"polars>=0.19",
|
"polars>=0.19",
|
||||||
|
"pillow",
|
||||||
]
|
]
|
||||||
dev = ["ruff", "pre-commit"]
|
dev = ["ruff", "pre-commit"]
|
||||||
docs = [
|
docs = [
|
||||||
|
|||||||
@@ -126,6 +126,10 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
"""
|
"""
|
||||||
Issue concurrent requests to retrieve the image data
|
Issue concurrent requests to retrieve the image data
|
||||||
"""
|
"""
|
||||||
|
return [
|
||||||
|
self.generate_image_embedding(image) for image in tqdm(images)
|
||||||
|
]
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
futures = [
|
futures = [
|
||||||
executor.submit(self.generate_image_embedding, image)
|
executor.submit(self.generate_image_embedding, image)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import io
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -26,7 +27,9 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
_GenericAlias,
|
_GenericAlias,
|
||||||
@@ -36,19 +39,30 @@ import numpy as np
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pydantic
|
import pydantic
|
||||||
import semver
|
import semver
|
||||||
|
from lance.arrow import (
|
||||||
|
EncodedImageType,
|
||||||
|
)
|
||||||
|
from lance.util import _check_huggingface
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
from .util import attempt_import_or_raise
|
||||||
|
|
||||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||||
try:
|
|
||||||
from pydantic_core import CoreSchema, core_schema
|
|
||||||
except ImportError:
|
|
||||||
if PYDANTIC_VERSION >= (2,):
|
|
||||||
raise
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
from .embeddings import EmbeddingFunctionConfig
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pydantic import GetJsonSchemaHandler
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
|
from pydantic_core import CoreSchema
|
||||||
|
except ImportError:
|
||||||
|
if PYDANTIC_VERSION >= (2,):
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class FixedSizeListMixin(ABC):
|
class FixedSizeListMixin(ABC):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -123,7 +137,7 @@ def Vector(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def __get_pydantic_core_schema__(
|
def __get_pydantic_core_schema__(
|
||||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||||
) -> CoreSchema:
|
) -> "CoreSchema":
|
||||||
return core_schema.no_info_after_validator_function(
|
return core_schema.no_info_after_validator_function(
|
||||||
cls,
|
cls,
|
||||||
core_schema.list_schema(
|
core_schema.list_schema(
|
||||||
@@ -181,25 +195,118 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
|||||||
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
||||||
child = py_type.__args__[0]
|
child = py_type.__args__[0]
|
||||||
return pa.list_(_py_type_to_arrow_type(child, field))
|
return pa.list_(_py_type_to_arrow_type(child, field))
|
||||||
|
elif _safe_is_huggingface_image():
|
||||||
|
import datasets
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageMixin(ABC):
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def value_arrow_type() -> pa.DataType:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def EncodedImage():
|
||||||
|
attempt_import_or_raise("PIL", "pillow or pip install lancedb[embeddings]")
|
||||||
|
import PIL.Image
|
||||||
|
|
||||||
|
class EncodedImage(bytes, ImageMixin):
|
||||||
|
"""Pydantic type for inlined images.
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Experimental feature.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
>>> import pydantic
|
||||||
|
>>> from lancedb.pydantic import EncodedImage
|
||||||
|
...
|
||||||
|
>>> class MyModel(pydantic.BaseModel):
|
||||||
|
... image: EncodedImage()
|
||||||
|
>>> schema = pydantic_to_schema(MyModel)
|
||||||
|
>>> assert schema == pa.schema([
|
||||||
|
... pa.field("image", pa.binary(), False)
|
||||||
|
... ])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "EncodedImage()"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_arrow_type() -> pa.DataType:
|
||||||
|
return EncodedImageType()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(
|
||||||
|
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||||
|
) -> "CoreSchema":
|
||||||
|
from_bytes_schema = core_schema.bytes_schema()
|
||||||
|
|
||||||
|
return core_schema.json_or_python_schema(
|
||||||
|
json_schema=from_bytes_schema,
|
||||||
|
python_schema=core_schema.union_schema(
|
||||||
|
[
|
||||||
|
core_schema.is_instance_schema(PIL.Image.Image),
|
||||||
|
from_bytes_schema,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||||
|
lambda instance: cls.validate(instance)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_json_schema__(
|
||||||
|
cls, _core_schema: "CoreSchema", handler: "GetJsonSchemaHandler"
|
||||||
|
) -> "JsonSchemaValue":
|
||||||
|
return handler(core_schema.bytes_schema())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||||
|
yield cls.validate
|
||||||
|
|
||||||
|
# For pydantic v2
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, v):
|
||||||
|
if isinstance(v, bytes):
|
||||||
|
return v
|
||||||
|
if isinstance(v, PIL.Image.Image):
|
||||||
|
with io.BytesIO() as output:
|
||||||
|
v.save(output, format=v.format)
|
||||||
|
return output.getvalue()
|
||||||
|
raise TypeError(
|
||||||
|
"EncodedImage can take bytes or PIL.Image.Image "
|
||||||
|
f"as input but got {type(v)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if PYDANTIC_VERSION < (2, 0):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||||
|
field_schema["type"] = "string"
|
||||||
|
field_schema["format"] = "binary"
|
||||||
|
|
||||||
|
return EncodedImage
|
||||||
|
|
||||||
|
|
||||||
if PYDANTIC_VERSION.major < 2:
|
if PYDANTIC_VERSION.major < 2:
|
||||||
|
def _safe_get_fields(model: pydantic.BaseModel):
|
||||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
return model.__fields__
|
||||||
return [
|
|
||||||
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
|
|
||||||
]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
def _safe_get_fields(model: pydantic.BaseModel):
|
||||||
|
return model.model_fields
|
||||||
|
|
||||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
|
||||||
return [
|
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||||
_pydantic_to_field(name, field)
|
return [
|
||||||
for name, field in model.model_fields.items()
|
_pydantic_to_field(name, field)
|
||||||
]
|
for name, field in _safe_get_fields(model).items()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||||
@@ -230,6 +337,9 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
|||||||
return pa.struct(fields)
|
return pa.struct(fields)
|
||||||
elif issubclass(field.annotation, FixedSizeListMixin):
|
elif issubclass(field.annotation, FixedSizeListMixin):
|
||||||
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
||||||
|
elif issubclass(field.annotation, ImageMixin):
|
||||||
|
return field.annotation.value_arrow_type()
|
||||||
|
|
||||||
return _py_type_to_arrow_type(field.annotation, field)
|
return _py_type_to_arrow_type(field.annotation, field)
|
||||||
|
|
||||||
|
|
||||||
@@ -335,13 +445,7 @@ class LanceModel(pydantic.BaseModel):
|
|||||||
"""
|
"""
|
||||||
Get the field names of this model.
|
Get the field names of this model.
|
||||||
"""
|
"""
|
||||||
return list(cls.safe_get_fields().keys())
|
return list(_safe_get_fields(cls).keys())
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def safe_get_fields(cls):
|
|
||||||
if PYDANTIC_VERSION.major < 2:
|
|
||||||
return cls.__fields__
|
|
||||||
return cls.model_fields
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
||||||
@@ -351,14 +455,16 @@ class LanceModel(pydantic.BaseModel):
|
|||||||
from .embeddings import EmbeddingFunctionConfig
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
|
|
||||||
vec_and_function = []
|
vec_and_function = []
|
||||||
for name, field_info in cls.safe_get_fields().items():
|
def get_vector_column(name, field_info):
|
||||||
func = get_extras(field_info, "vector_column_for")
|
fun = get_extras(field_info, "vector_column_for")
|
||||||
if func is not None:
|
if func is not None:
|
||||||
vec_and_function.append([name, func])
|
vec_and_function.append([name, func])
|
||||||
|
visit_fields(_safe_get_fields(cls).items(), get_vector_column)
|
||||||
|
|
||||||
configs = []
|
configs = []
|
||||||
|
# find the source columns for each one
|
||||||
for vec, func in vec_and_function:
|
for vec, func in vec_and_function:
|
||||||
for source, field_info in cls.safe_get_fields().items():
|
def get_source_column(source, field_info):
|
||||||
src_func = get_extras(field_info, "source_column_for")
|
src_func = get_extras(field_info, "source_column_for")
|
||||||
if src_func is func:
|
if src_func is func:
|
||||||
# note we can't use == here since the function is a pydantic
|
# note we can't use == here since the function is a pydantic
|
||||||
@@ -371,20 +477,48 @@ class LanceModel(pydantic.BaseModel):
|
|||||||
source_column=source, vector_column=vec, function=func
|
source_column=source, vector_column=vec, function=func
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
visit_fields(_safe_get_fields(cls).items(), get_source_column)
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def visit_fields(fields: Iterable[Tuple[str, FieldInfo]],
|
||||||
|
visitor: Callable[[str, FieldInfo], Any]):
|
||||||
|
"""
|
||||||
|
Visit all the leaf fields in a Pydantic model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fields : Iterable[Tuple(str, FieldInfo)]
|
||||||
|
The fields to visit.
|
||||||
|
visitor : Callable[[str, FieldInfo], Any]
|
||||||
|
The visitor function.
|
||||||
|
"""
|
||||||
|
for name, field_info in fields:
|
||||||
|
# if the field is a pydantic model then
|
||||||
|
# visit all subfields
|
||||||
|
if (isinstance(getattr(field_info, "annotation"), type)
|
||||||
|
and issubclass(field_info.annotation, pydantic.BaseModel)):
|
||||||
|
visit_fields(_safe_get_fields(field_info.annotation).items(),
|
||||||
|
_add_prefix(visitor, name))
|
||||||
|
else:
|
||||||
|
visitor(name, field_info)
|
||||||
|
|
||||||
|
|
||||||
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
def _add_prefix(visitor: Callable[[str, FieldInfo], Any], prefix: str) -> Callable[[str, FieldInfo], Any]:
|
||||||
"""
|
def prefixed_visitor(name: str, field: FieldInfo):
|
||||||
Get the extra metadata from a Pydantic FieldInfo.
|
return visitor(f"{prefix}.{name}", field)
|
||||||
"""
|
return prefixed_visitor
|
||||||
if PYDANTIC_VERSION.major >= 2:
|
|
||||||
return (field_info.json_schema_extra or {}).get(key)
|
|
||||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
|
||||||
|
|
||||||
|
|
||||||
if PYDANTIC_VERSION.major < 2:
|
if PYDANTIC_VERSION.major < 2:
|
||||||
|
|
||||||
|
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get the extra metadata from a Pydantic FieldInfo.
|
||||||
|
"""
|
||||||
|
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||||
|
|
||||||
|
|
||||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert a Pydantic model to a dictionary.
|
Convert a Pydantic model to a dictionary.
|
||||||
@@ -393,6 +527,13 @@ if PYDANTIC_VERSION.major < 2:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get the extra metadata from a Pydantic FieldInfo.
|
||||||
|
"""
|
||||||
|
return (field_info.json_schema_extra or {}).get(key)
|
||||||
|
|
||||||
|
|
||||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert a Pydantic model to a dictionary.
|
Convert a Pydantic model to a dictionary.
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import pyarrow as pa
|
|||||||
import pyarrow.compute as pc
|
import pyarrow.compute as pc
|
||||||
import pyarrow.fs as pa_fs
|
import pyarrow.fs as pa_fs
|
||||||
from lance import LanceDataset
|
from lance import LanceDataset
|
||||||
|
from lance.dependencies import _check_for_hugging_face
|
||||||
from lance.vector import vec_to_table
|
from lance.vector import vec_to_table
|
||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
@@ -74,7 +75,16 @@ def _sanitize_data(
|
|||||||
on_bad_vectors: str,
|
on_bad_vectors: str,
|
||||||
fill_value: Any,
|
fill_value: Any,
|
||||||
):
|
):
|
||||||
if isinstance(data, list):
|
import pdb; pdb.set_trace()
|
||||||
|
if _check_for_hugging_face(data):
|
||||||
|
# Huggingface datasets
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
if isinstance(data, datasets.Dataset):
|
||||||
|
if schema is None:
|
||||||
|
schema = data.features.arrow_schema
|
||||||
|
data = data.data.to_batches()
|
||||||
|
elif isinstance(data, list):
|
||||||
# convert to list of dict if data is a bunch of LanceModels
|
# convert to list of dict if data is a bunch of LanceModels
|
||||||
if isinstance(data[0], LanceModel):
|
if isinstance(data[0], LanceModel):
|
||||||
schema = data[0].__class__.to_arrow_schema()
|
schema = data[0].__class__.to_arrow_schema()
|
||||||
@@ -136,11 +146,10 @@ def _to_record_batch_generator(
|
|||||||
data: Iterable, schema, metadata, on_bad_vectors, fill_value
|
data: Iterable, schema, metadata, on_bad_vectors, fill_value
|
||||||
):
|
):
|
||||||
for batch in data:
|
for batch in data:
|
||||||
if not isinstance(batch, pa.RecordBatch):
|
if isinstance(batch, pa.RecordBatch):
|
||||||
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
batch = pa.Table.from_batches([batch])
|
||||||
for batch in table.to_batches():
|
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||||
yield batch
|
for batch in table.to_batches():
|
||||||
else:
|
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
BIN
python/python/tests/images/1.png
Normal file
BIN
python/python/tests/images/1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 83 B |
@@ -12,17 +12,27 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pydantic
|
import pydantic
|
||||||
import pytest
|
import pytest
|
||||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from lancedb.pydantic import (
|
||||||
|
PYDANTIC_VERSION,
|
||||||
|
EncodedImage,
|
||||||
|
LanceModel,
|
||||||
|
Vector,
|
||||||
|
pydantic_to_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.version_info < (3, 9),
|
sys.version_info < (3, 9),
|
||||||
@@ -243,3 +253,23 @@ def test_lance_model():
|
|||||||
|
|
||||||
t = TestModel()
|
t = TestModel()
|
||||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
|
def test_schema_with_images():
|
||||||
|
pytest.importorskip("PIL")
|
||||||
|
import PIL.Image
|
||||||
|
|
||||||
|
class TestModel(LanceModel):
|
||||||
|
img: EncodedImage()
|
||||||
|
|
||||||
|
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
|
||||||
|
with open(img_path, "rb") as f:
|
||||||
|
img_bytes = f.read()
|
||||||
|
|
||||||
|
m1 = TestModel(img=PIL.Image.open(img_path))
|
||||||
|
m2 = TestModel(img=img_bytes)
|
||||||
|
|
||||||
|
def tobytes(m):
|
||||||
|
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
|
||||||
|
|
||||||
|
assert tobytes(m1) == tobytes(m2)
|
||||||
|
|||||||
@@ -12,6 +12,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import io
|
||||||
|
import os
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -20,19 +22,21 @@ from typing import List
|
|||||||
from unittest.mock import PropertyMock, patch
|
from unittest.mock import PropertyMock, patch
|
||||||
|
|
||||||
import lance
|
import lance
|
||||||
import lancedb
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import polars as pl
|
import polars as pl
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
from lance.arrow import EncodedImageType
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import lancedb
|
||||||
from lancedb.conftest import MockTextEmbeddingFunction
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
from lancedb.db import AsyncConnection, LanceDBConnection
|
from lancedb.db import AsyncConnection, LanceDBConnection
|
||||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import EncodedImage, LanceModel, Vector
|
||||||
from lancedb.table import LanceTable
|
from lancedb.table import LanceTable
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class MockDB:
|
class MockDB:
|
||||||
@@ -108,6 +112,7 @@ def test_create_table(db):
|
|||||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||||
pa.field("item", pa.string()),
|
pa.field("item", pa.string()),
|
||||||
pa.field("price", pa.float32()),
|
pa.field("price", pa.float32()),
|
||||||
|
pa.field("encoded_image", EncodedImageType()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
expected = pa.Table.from_arrays(
|
expected = pa.Table.from_arrays(
|
||||||
@@ -115,13 +120,26 @@ def test_create_table(db):
|
|||||||
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||||
pa.array(["foo", "bar"]),
|
pa.array(["foo", "bar"]),
|
||||||
pa.array([10.0, 20.0]),
|
pa.array([10.0, 20.0]),
|
||||||
|
pa.ExtensionArray.from_storage(
|
||||||
|
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
|
||||||
|
),
|
||||||
],
|
],
|
||||||
schema=schema,
|
schema=schema,
|
||||||
)
|
)
|
||||||
data = [
|
data = [
|
||||||
[
|
[
|
||||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
{
|
||||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
"vector": [3.1, 4.1],
|
||||||
|
"item": "foo",
|
||||||
|
"price": 10.0,
|
||||||
|
"encoded_image": b"foo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"vector": [5.9, 26.5],
|
||||||
|
"item": "bar",
|
||||||
|
"price": 20.0,
|
||||||
|
"encoded_image": b"bar",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
df = pd.DataFrame(data[0])
|
df = pd.DataFrame(data[0])
|
||||||
@@ -1025,3 +1043,22 @@ async def test_time_travel(db_async: AsyncConnection):
|
|||||||
# Can't use restore if not checked out
|
# Can't use restore if not checked out
|
||||||
with pytest.raises(ValueError, match="checkout before running restore"):
|
with pytest.raises(ValueError, match="checkout before running restore"):
|
||||||
await table.restore()
|
await table.restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_image(tmp_path):
|
||||||
|
pytest.importorskip("PIL")
|
||||||
|
import PIL.Image
|
||||||
|
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
|
||||||
|
class TestModel(LanceModel):
|
||||||
|
img: EncodedImage()
|
||||||
|
|
||||||
|
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
|
||||||
|
m1 = TestModel(img=PIL.Image.open(img_path))
|
||||||
|
|
||||||
|
def tobytes(m):
|
||||||
|
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
|
||||||
|
|
||||||
|
table = LanceTable.create(db, "my_table", schema=TestModel)
|
||||||
|
table.add([m1])
|
||||||
|
|||||||
Reference in New Issue
Block a user