mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 18:32:55 +00:00
[Python] Support pydantic v1 as well (#337)
Support both Pydantic v1 and v2 (breaking changes)
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pydantic adapter for LanceDB"""
|
||||
"""Pydantic (v1 / v2) adapter for LanceDB"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -19,11 +19,19 @@ import inspect
|
||||
import sys
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Type, Union, _GenericAlias
|
||||
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
import semver
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
except ImportError:
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
raise
|
||||
|
||||
|
||||
class FixedSizeListMixin(ABC):
|
||||
@@ -73,6 +81,9 @@ def vector(
|
||||
|
||||
# TODO: make a public parameterized type.
|
||||
class FixedSizeList(list, FixedSizeListMixin):
|
||||
def __repr__(self):
|
||||
return f"FixedSizeList(dim={dim})"
|
||||
|
||||
@staticmethod
|
||||
def dim() -> int:
|
||||
return dim
|
||||
@@ -94,6 +105,25 @@ def vector(
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
|
||||
raise TypeError("A list of numbers or numpy.ndarray is needed")
|
||||
return v
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["items"] = {"type": "number"}
|
||||
field_schema["maxItems"] = dim
|
||||
field_schema["minItems"] = dim
|
||||
|
||||
return FixedSizeList
|
||||
|
||||
|
||||
@@ -120,11 +150,20 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
|
||||
)
|
||||
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
fields = []
|
||||
for name, field in model.model_fields.items():
|
||||
fields.append(_pydantic_to_field(name, field))
|
||||
return fields
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
|
||||
]
|
||||
|
||||
else:
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field)
|
||||
for name, field in model.model_fields.items()
|
||||
]
|
||||
|
||||
|
||||
def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.1.12"
|
||||
dependencies = ["pylance~=0.5.8", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic>=2", "attr"]
|
||||
dependencies = ["pylance~=0.5.8", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr", "semver"]
|
||||
description = "lancedb"
|
||||
authors = [
|
||||
{ name = "LanceDB Devs", email = "dev@lancedb.com" },
|
||||
@@ -52,3 +52,6 @@ requires = [
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
@@ -20,7 +20,7 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from lancedb.pydantic import pydantic_to_schema, vector
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, pydantic_to_schema, vector
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -111,10 +111,16 @@ def test_fixed_size_list_field():
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
|
||||
assert json.loads(data.model_dump_json()) == {
|
||||
"vec": list(range(16)),
|
||||
"li": [1, 2, 3],
|
||||
}
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
assert json.loads(data.model_dump_json()) == {
|
||||
"vec": list(range(16)),
|
||||
"li": [1, 2, 3],
|
||||
}
|
||||
else:
|
||||
assert data.dict() == {
|
||||
"vec": list(range(16)),
|
||||
"li": [1, 2, 3],
|
||||
}
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == pa.schema(
|
||||
@@ -124,7 +130,11 @@ def test_fixed_size_list_field():
|
||||
]
|
||||
)
|
||||
|
||||
json_schema = TestModel.model_json_schema()
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
json_schema = TestModel.model_json_schema()
|
||||
else:
|
||||
json_schema = TestModel.schema()
|
||||
|
||||
assert json_schema == {
|
||||
"properties": {
|
||||
"vec": {
|
||||
|
||||
Reference in New Issue
Block a user