From ad51e2dd1fafb0e6f8e5e2c65d3b3d32567eeeaf Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Wed, 28 Jan 2026 21:08:18 -0800 Subject: [PATCH] fix: support pydantic list of structs or optional struct (#2953) Closes #2950 *This code is generated by codex-gpt5.2* --- python/python/lancedb/pydantic.py | 66 ++++--- python/python/tests/test_pydantic.py | 247 +++++++++++++++++++++++++++ 2 files changed, 294 insertions(+), 19 deletions(-) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 79f005177..653ea3333 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -275,7 +275,7 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType: return pa.timestamp("us", tz=tz) elif getattr(py_type, "__origin__", None) in (list, tuple): child = py_type.__args__[0] - return pa.list_(_py_type_to_arrow_type(child, field)) + return _pydantic_list_child_to_arrow(child, field) raise TypeError( f"Converting Pydantic type to Arrow Type: unsupported type {py_type}." ) @@ -298,12 +298,18 @@ else: def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType: + def _safe_issubclass(candidate: Any, base: type) -> bool: + try: + return issubclass(candidate, base) + except TypeError: + return False + if inspect.isclass(tp): - if issubclass(tp, pydantic.BaseModel): + if _safe_issubclass(tp, pydantic.BaseModel): # Struct fields = _pydantic_model_to_fields(tp) return pa.struct(fields) - if issubclass(tp, FixedSizeListMixin): + if _safe_issubclass(tp, FixedSizeListMixin): if getattr(tp, "is_multi_vector", lambda: False)(): return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim())) # For regular Vector @@ -311,45 +317,67 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType: return _py_type_to_arrow_type(tp, field) +def _pydantic_list_child_to_arrow(child: Any, field: FieldInfo) -> pa.DataType: + unwrapped = _unwrap_optional_annotation(child) + if unwrapped is not None: + return pa.list_( + pa.field("item", _pydantic_type_to_arrow_type(unwrapped, field), True) + ) + return pa.list_(_pydantic_type_to_arrow_type(child, field)) + + +def _unwrap_optional_annotation(annotation: Any) -> Any | None: + if isinstance(annotation, (_GenericAlias, GenericAlias)): + origin = annotation.__origin__ + args = annotation.__args__ + if origin == Union: + non_none = [arg for arg in args if arg is not type(None)] + if len(non_none) == 1 and len(non_none) != len(args): + return non_none[0] + elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType): + args = annotation.__args__ + non_none = [arg for arg in args if arg is not type(None)] + if len(non_none) == 1 and len(non_none) != len(args): + return non_none[0] + return None + + def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: """Convert a Pydantic FieldInfo to Arrow DataType""" + unwrapped = _unwrap_optional_annotation(field.annotation) + if unwrapped is not None: + return _pydantic_type_to_arrow_type(unwrapped, field) if isinstance(field.annotation, (_GenericAlias, GenericAlias)): origin = field.annotation.__origin__ args = field.annotation.__args__ if origin is list: child = args[0] - return pa.list_(_py_type_to_arrow_type(child, field)) - elif origin == Union: - if len(args) == 2 and args[1] is type(None): - return _pydantic_type_to_arrow_type(args[0], field) - elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType): - args = field.annotation.__args__ - if len(args) == 2: - for typ in args: - if typ is type(None): - continue - return _py_type_to_arrow_type(typ, field) + return _pydantic_list_child_to_arrow(child, field) return _pydantic_type_to_arrow_type(field.annotation, field) def is_nullable(field: FieldInfo) -> bool: """Check if a Pydantic FieldInfo is nullable.""" + if _unwrap_optional_annotation(field.annotation) is not None: + return True if isinstance(field.annotation, (_GenericAlias, GenericAlias)): origin = field.annotation.__origin__ args = field.annotation.__args__ if origin == Union: - if len(args) == 2 and args[1] is type(None): + if any(typ is type(None) for typ in args): return True elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType): args = field.annotation.__args__ for typ in args: if typ is type(None): return True - elif inspect.isclass(field.annotation) and issubclass( - field.annotation, FixedSizeListMixin - ): - return field.annotation.nullable() + elif inspect.isclass(field.annotation): + try: + if issubclass(field.annotation, FixedSizeListMixin): + return field.annotation.nullable() + except TypeError: + return False return False diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 41dbb6eca..2a889e7e5 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -105,6 +105,253 @@ def test_optional_types_py310(): assert schema == expect_schema +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="using PEP 604 union types requires python3.10 or higher", +) +def test_optional_structs_py310(): + class SplitInfo(pydantic.BaseModel): + start_frame: int + end_frame: int + + class TestModel(pydantic.BaseModel): + id: str + split: SplitInfo | None = None + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("id", pa.utf8(), False), + pa.field( + "split", + pa.struct( + [ + pa.field("start_frame", pa.int64(), False), + pa.field("end_frame", pa.int64(), False), + ] + ), + True, + ), + ] + ) + assert schema == expect_schema + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="using PEP 604 union types requires python3.10 or higher", +) +def test_optional_struct_list_py310(): + class SplitInfo(pydantic.BaseModel): + start_frame: int + end_frame: int + + class TestModel(pydantic.BaseModel): + id: str + splits: list[SplitInfo] | None = None + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("id", pa.utf8(), False), + pa.field( + "splits", + pa.list_( + pa.struct( + [ + pa.field("start_frame", pa.int64(), False), + pa.field("end_frame", pa.int64(), False), + ] + ) + ), + True, + ), + ] + ) + assert schema == expect_schema + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="using native type alias requires python3.9 or higher", +) +def test_nested_struct_list(): + class SplitInfo(pydantic.BaseModel): + start_frame: int + end_frame: int + + class TestModel(pydantic.BaseModel): + id: str + splits: list[SplitInfo] + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("id", pa.utf8(), False), + pa.field( + "splits", + pa.list_( + pa.struct( + [ + pa.field("start_frame", pa.int64(), False), + pa.field("end_frame", pa.int64(), False), + ] + ) + ), + False, + ), + ] + ) + assert schema == expect_schema + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="using native type alias requires python3.9 or higher", +) +def test_nested_struct_list_optional(): + class SplitInfo(pydantic.BaseModel): + start_frame: int + end_frame: int + + class TestModel(pydantic.BaseModel): + id: str + splits: Optional[list[SplitInfo]] = None + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("id", pa.utf8(), False), + pa.field( + "splits", + pa.list_( + pa.struct( + [ + pa.field("start_frame", pa.int64(), False), + pa.field("end_frame", pa.int64(), False), + ] + ) + ), + True, + ), + ] + ) + assert schema == expect_schema + + +def test_nested_struct_list_optional_items(): + class SplitInfo(pydantic.BaseModel): + start_frame: int + end_frame: int + + class TestModel(pydantic.BaseModel): + id: str + splits: list[Optional[SplitInfo]] + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("id", pa.utf8(), False), + pa.field( + "splits", + pa.list_( + pa.field( + "item", + pa.struct( + [ + pa.field("start_frame", pa.int64(), False), + pa.field("end_frame", pa.int64(), False), + ] + ), + True, + ) + ), + False, + ), + ] + ) + assert schema == expect_schema + + +def test_nested_struct_list_optional_container_and_items(): + class SplitInfo(pydantic.BaseModel): + start_frame: int + end_frame: int + + class TestModel(pydantic.BaseModel): + id: str + splits: Optional[list[Optional[SplitInfo]]] = None + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("id", pa.utf8(), False), + pa.field( + "splits", + pa.list_( + pa.field( + "item", + pa.struct( + [ + pa.field("start_frame", pa.int64(), False), + pa.field("end_frame", pa.int64(), False), + ] + ), + True, + ) + ), + True, + ), + ] + ) + assert schema == expect_schema + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="using PEP 604 union types requires python3.10 or higher", +) +def test_nested_struct_list_optional_items_pep604(): + class SplitInfo(pydantic.BaseModel): + start_frame: int + end_frame: int + + class TestModel(pydantic.BaseModel): + id: str + splits: list[SplitInfo | None] + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("id", pa.utf8(), False), + pa.field( + "splits", + pa.list_( + pa.field( + "item", + pa.struct( + [ + pa.field("start_frame", pa.int64(), False), + pa.field("end_frame", pa.int64(), False), + ] + ), + True, + ) + ), + False, + ), + ] + ) + assert schema == expect_schema + + @pytest.mark.skipif( sys.version_info > (3, 8), reason="using native type alias requires python3.9 or higher",