From 08944bf4fdef70292853c0c036dbbe4f12f6dfce Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 13 Jul 2023 11:16:37 -0700 Subject: [PATCH] [Python] Convert Pydantic Model to Arrow Schema (#291) Provide utility to automatically convert Pydantic model to Arrow Schema Closes #256 --- docs/src/python/python.md | 11 +++ python/lancedb/pydantic.py | 169 ++++++++++++++++++++++++++++++++++ python/pyproject.toml | 2 +- python/tests/test_pydantic.py | 155 +++++++++++++++++++++++++++++++ 4 files changed, 336 insertions(+), 1 deletion(-) create mode 100644 python/lancedb/pydantic.py create mode 100644 python/tests/test_pydantic.py diff --git a/docs/src/python/python.md b/docs/src/python/python.md index 8fc2150e..99b6dd15 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -47,6 +47,17 @@ pip install lancedb ## Utilities ::: lancedb.schema.schema_to_dict + ::: lancedb.schema.dict_to_schema + ::: lancedb.vector +## Integrations + +### Pydantic + +::: lancedb.pydantic.pydantic_to_schema + +::: lancedb.pydantic.vector + + diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py new file mode 100644 index 00000000..31cf2d5b --- /dev/null +++ b/python/lancedb/pydantic.py @@ -0,0 +1,169 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Pydantic adapter for LanceDB.

Converts Pydantic models into PyArrow schemas and provides :func:`vector`, a
fixed-size list field type.  Requires Pydantic v2: the module relies on
``pydantic_core`` and on ``BaseModel.model_fields``.
"""

from __future__ import annotations

import inspect
import types
from abc import ABC, abstractmethod
from typing import Any, List, Type, Union, get_args, get_origin

import pyarrow as pa
import pydantic
from pydantic_core import CoreSchema, core_schema

_NONE_TYPE = type(None)

# ``Optional[X]`` / ``Union[X, None]`` report ``typing.Union`` as their origin,
# while ``X | None`` (PEP 604) reports ``types.UnionType``, which only exists
# on Python 3.10+.  Both spellings must be treated the same way.
_UNION_ORIGINS = tuple(
    origin
    for origin in (Union, getattr(types, "UnionType", None))
    if origin is not None
)

# (python type, arrow type factory) pairs, compared with ``==`` in order.
_PRIMITIVE_ARROW_TYPES = (
    (int, pa.int64),
    (float, pa.float64),
    (str, pa.utf8),
    (bool, pa.bool_),
    (bytes, pa.binary),
)


class FixedSizeListMixin(ABC):
    """Interface of the fixed-size list types created by :func:`vector`."""

    @staticmethod
    @abstractmethod
    def dim() -> int:
        """Return the fixed number of elements of the list."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def value_arrow_type() -> pa.DataType:
        """Return the Arrow data type of the list elements."""
        raise NotImplementedError


def vector(
    dim: int, value_type: pa.DataType = pa.float32()
) -> Type[FixedSizeListMixin]:
    """Pydantic Vector Type.

    Note
    ----
    Experimental feature.

    Parameters
    ----------
    dim : int
        Exact number of elements every value must have.
    value_type : pa.DataType, default pa.float32()
        Arrow type reported for the elements when the field is converted to an
        Arrow schema.  Pydantic validation always treats elements as floats,
        whatever ``value_type`` is.

    Returns
    -------
    A ``list`` subclass implementing :class:`FixedSizeListMixin`, usable as a
    Pydantic field annotation.

    Examples
    --------

    >>> import pydantic
    >>> from lancedb.pydantic import vector
    ...
    >>> class MyModel(pydantic.BaseModel):
    ...     vector: vector(756)
    ...     id: int
    ...     description: str
    """

    # TODO: make a public parameterized type.
    class FixedSizeList(list, FixedSizeListMixin):
        @staticmethod
        def dim() -> int:
            return dim

        @staticmethod
        def value_arrow_type() -> pa.DataType:
            return value_type

        @classmethod
        def __get_pydantic_core_schema__(
            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
        ) -> CoreSchema:
            # Validate as a list of exactly ``dim`` floats, then wrap the
            # validated list in this class.
            return core_schema.no_info_after_validator_function(
                cls,
                core_schema.list_schema(
                    min_length=dim,
                    max_length=dim,
                    items_schema=core_schema.float_schema(),
                ),
            )

    return FixedSizeList


def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
    """Convert a primitive Python type to an Arrow DataType.

    Raises
    ------
    TypeError
        If the type is not supported.
    """
    for candidate, arrow_factory in _PRIMITIVE_ARROW_TYPES:
        if py_type == candidate:
            return arrow_factory()
    raise TypeError(
        f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
    )


def _pydantic_model_to_fields(model: Type[pydantic.BaseModel]) -> List[pa.Field]:
    """Convert every field of a Pydantic model class to a PyArrow Field."""
    return [
        _pydantic_to_field(name, field) for name, field in model.model_fields.items()
    ]


def _annotation_to_arrow_type(annotation: Any) -> pa.DataType:
    """Convert a (possibly nested) type annotation to an Arrow DataType.

    Handles ``List[X]`` / ``list[X]``, ``Optional[X]`` / ``X | None``, nested
    Pydantic models, :func:`vector` types and the primitive types, recursing
    into ``X`` so that combinations such as ``Optional[List[int]]`` work.

    Raises
    ------
    TypeError
        If the annotation is not supported.
    """
    origin = get_origin(annotation)
    if origin is not None:
        args = get_args(annotation)
        # A bare ``List`` has no args; let it fall through to the TypeError.
        if origin is list and args:
            return pa.list_(_annotation_to_arrow_type(args[0]))
        if origin in _UNION_ORIGINS:
            # Only "X or None" maps to a single Arrow type; nullability itself
            # is recorded on the field by ``is_nullable``.
            non_null = [arg for arg in args if arg is not _NONE_TYPE]
            if len(non_null) == 1:
                return _annotation_to_arrow_type(non_null[0])
    elif inspect.isclass(annotation):
        if issubclass(annotation, pydantic.BaseModel):
            # Nested model -> struct
            return pa.struct(_pydantic_model_to_fields(annotation))
        if issubclass(annotation, FixedSizeListMixin):
            return pa.list_(annotation.value_arrow_type(), annotation.dim())
    return _py_type_to_arrow_type(annotation)


def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
    """Convert a Pydantic FieldInfo to Arrow DataType.

    Raises
    ------
    TypeError
        If the field's annotation is not supported.
    """
    return _annotation_to_arrow_type(field.annotation)


def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
    """Check if a Pydantic FieldInfo is nullable.

    A field is nullable when its annotation is a union that includes ``None``,
    i.e. ``Optional[X]``, ``Union[X, None]`` or ``X | None``.
    """
    annotation = field.annotation
    return get_origin(annotation) in _UNION_ORIGINS and _NONE_TYPE in get_args(
        annotation
    )


def _pydantic_to_field(name: str, field: pydantic.fields.FieldInfo) -> pa.Field:
    """Convert a Pydantic field to a PyArrow Field."""
    dt = _pydantic_to_arrow_type(field)
    return pa.field(name, dt, is_nullable(field))


def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
    """Convert a Pydantic model to a PyArrow Schema.

    Parameters
    ----------
    model : Type[pydantic.BaseModel]
        The Pydantic BaseModel to convert to Arrow Schema.

    Returns
    -------
    A PyArrow Schema.

    Raises
    ------
    TypeError
        If a field of the model has an unsupported type.
    """
    fields = _pydantic_model_to_fields(model)
    return pa.schema(fields)
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Pydantic-to-Arrow schema conversion in ``lancedb.pydantic``."""

import json
import sys
from typing import List, Optional

import pyarrow as pa
import pydantic
import pytest

from lancedb.pydantic import pydantic_to_schema, vector


def _expected_model_schema() -> pa.Schema:
    """Arrow schema that both ``TestModel`` variants below must convert to."""
    nested_struct = pa.struct(
        [pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
    )
    return pa.schema(
        [
            pa.field("id", pa.int64(), False),
            pa.field("s", pa.utf8(), False),
            pa.field("vec", pa.list_(pa.float64()), False),
            pa.field("li", pa.list_(pa.int64()), False),
            pa.field("opt", pa.utf8(), True),
            pa.field("st", nested_struct, False),
        ]
    )


@pytest.mark.skipif(
    sys.version_info < (3, 9),
    reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow():
    """A model using the builtin generic ``list[float]`` converts correctly."""

    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: list[float]
        li: List[int]
        opt: Optional[str] = None
        st: StructModel

    # Instantiating the model checks that it validates sample data at all.
    nested = StructModel(a="a", b=1.0)
    TestModel(id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=nested)

    actual = pydantic_to_schema(TestModel)
    assert actual == _expected_model_schema()


def test_pydantic_to_arrow_py38():
    """Same model spelled with ``typing.List`` only, so it also runs on 3.8."""

    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: List[float]
        li: List[int]
        opt: Optional[str] = None
        st: StructModel

    # Instantiating the model checks that it validates sample data at all.
    nested = StructModel(a="a", b=1.0)
    TestModel(id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=nested)

    actual = pydantic_to_schema(TestModel)
    assert actual == _expected_model_schema()


def test_fixed_size_list_field():
    """``vector(16)`` round-trips through JSON, Arrow schema and JSON schema."""

    class TestModel(pydantic.BaseModel):
        vec: vector(16)
        li: List[int]

    sixteen = list(range(16))
    instance = TestModel(vec=sixteen, li=[1, 2, 3])
    dumped = json.loads(instance.model_dump_json())
    assert dumped == {
        "vec": list(range(16)),
        "li": [1, 2, 3],
    }

    arrow_schema = pydantic_to_schema(TestModel)
    assert arrow_schema == pa.schema(
        [
            pa.field("vec", pa.list_(pa.float32(), 16), False),
            pa.field("li", pa.list_(pa.int64()), False),
        ]
    )

    vec_property = {
        "items": {"type": "number"},
        "maxItems": 16,
        "minItems": 16,
        "title": "Vec",
        "type": "array",
    }
    li_property = {"items": {"type": "integer"}, "title": "Li", "type": "array"}
    assert TestModel.model_json_schema() == {
        "properties": {"vec": vec_property, "li": li_property},
        "required": ["vec", "li"],
        "title": "TestModel",
        "type": "object",
    }


def test_fixed_size_list_validation():
    """Only values with exactly ``dim`` elements pass validation."""

    class TestModel(pydantic.BaseModel):
        vec: vector(8)

    # One element too many, then one too few.
    for wrong_length in (9, 7):
        with pytest.raises(pydantic.ValidationError):
            TestModel(vec=range(wrong_length))

    TestModel(vec=range(8))