diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 2584d075..ed55e677 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Pydantic adapter for LanceDB""" +"""Pydantic (v1 / v2) adapter for LanceDB""" from __future__ import annotations @@ -19,11 +19,19 @@ import inspect import sys import types from abc import ABC, abstractmethod -from typing import Any, List, Type, Union, _GenericAlias +from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias +import numpy as np import pyarrow as pa import pydantic -from pydantic_core import CoreSchema, core_schema +import semver + +PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) +try: + from pydantic_core import CoreSchema, core_schema +except ImportError: + if PYDANTIC_VERSION >= (2,): + raise class FixedSizeListMixin(ABC): @@ -73,6 +81,9 @@ def vector( # TODO: make a public parameterized type. class FixedSizeList(list, FixedSizeListMixin): + def __repr__(self): + return f"FixedSizeList(dim={dim})" + @staticmethod def dim() -> int: return dim @@ -94,6 +105,25 @@ def vector( ), ) + @classmethod + def __get_validators__(cls) -> Generator[Callable, None, None]: + yield cls.validate + + # For pydantic v1 + @classmethod + def validate(cls, v): + if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim: + raise TypeError("A list of numbers or numpy.ndarray is needed") + return v + + if PYDANTIC_VERSION < (2, 0): + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]): + field_schema["items"] = {"type": "number"} + field_schema["maxItems"] = dim + field_schema["minItems"] = dim + return FixedSizeList @@ -120,11 +150,20 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType: ) -def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: - fields = [] - for name, field in model.model_fields.items(): - fields.append(_pydantic_to_field(name, field)) - return fields +if PYDANTIC_VERSION.major < 2: + + def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: + return [ + _pydantic_to_field(name, field) for name, field in model.__fields__.items() + ] + +else: + + def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: + return [ + _pydantic_to_field(name, field) + for name, field in model.model_fields.items() + ] def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType: diff --git a/python/pyproject.toml b/python/pyproject.toml index 3af055e7..8dc85828 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "lancedb" version = "0.1.12" -dependencies = ["pylance~=0.5.8", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic>=2", "attr"] +dependencies = ["pylance~=0.5.8", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr", "semver"] description = "lancedb" authors = [ { name = "LanceDB Devs", email = "dev@lancedb.com" }, @@ -52,3 +52,6 @@ requires = [ "wheel", ] build-backend = "setuptools.build_meta" + +[tool.isort] +profile = "black" diff --git a/python/tests/test_pydantic.py b/python/tests/test_pydantic.py index d4189bde..c38262db 100644 --- a/python/tests/test_pydantic.py +++ b/python/tests/test_pydantic.py @@ -20,7 +20,7 @@ import pyarrow as pa import pydantic import pytest -from lancedb.pydantic import pydantic_to_schema, vector +from lancedb.pydantic import PYDANTIC_VERSION, pydantic_to_schema, vector @pytest.mark.skipif( @@ -111,10 +111,16 @@ def test_fixed_size_list_field(): li: List[int] data = TestModel(vec=list(range(16)), li=[1, 2, 3]) - assert json.loads(data.model_dump_json()) == { - "vec": list(range(16)), - "li": [1, 2, 3], - } + if PYDANTIC_VERSION >= (2,): + assert json.loads(data.model_dump_json()) == { + "vec": list(range(16)), + "li": [1, 2, 3], + } + else: + assert data.dict() == { + "vec": list(range(16)), + "li": [1, 2, 3], + } schema = pydantic_to_schema(TestModel) assert schema == pa.schema( @@ -124,7 +130,11 @@ def test_fixed_size_list_field(): ] ) - json_schema = TestModel.model_json_schema() + if PYDANTIC_VERSION >= (2,): + json_schema = TestModel.model_json_schema() + else: + json_schema = TestModel.schema() + assert json_schema == { "properties": { "vec": {