def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
    """Convert one already-unwrapped annotation to an Arrow DataType.

    Parameters
    ----------
    tp : Any
        The annotation to convert, with any ``Optional`` / ``list`` wrapper
        already removed by the caller.
    field : FieldInfo
        The Pydantic field the annotation came from. It is only forwarded to
        ``_py_type_to_arrow_type``.

    Returns
    -------
    pa.DataType
        ``pa.struct`` built from ``_pydantic_model_to_fields`` for
        ``pydantic.BaseModel`` subclasses, ``pa.list_(value_type, dim)`` for
        ``FixedSizeListMixin`` subclasses, and otherwise the result of
        ``_py_type_to_arrow_type``.
    """
    # issubclass() raises on non-class objects (e.g. typing aliases), so guard
    # with inspect.isclass first.
    if inspect.isclass(tp):
        if issubclass(tp, pydantic.BaseModel):
            # Nested model -> struct
            return pa.struct(_pydantic_model_to_fields(tp))
        if issubclass(tp, FixedSizeListMixin):
            return pa.list_(tp.value_arrow_type(), tp.dim())
    return _py_type_to_arrow_type(tp, field)


def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
    """Convert a Pydantic FieldInfo to Arrow DataType

    ``List[X]``, ``Optional[X]`` and (on Python 3.10+) ``X | None`` are
    unwrapped here; the inner type is then converted by
    ``_pydantic_type_to_arrow_type`` so that nested models and
    ``FixedSizeListMixin`` types are handled the same way regardless of how
    the wrapper is spelled. Any other annotation is passed to that helper
    as-is.

    Parameters
    ----------
    field : FieldInfo
        The Pydantic field whose ``annotation`` is converted.

    Returns
    -------
    pa.DataType
        The Arrow type for the field's annotation.
    """
    annotation = field.annotation

    if isinstance(annotation, (_GenericAlias, GenericAlias)):
        origin = annotation.__origin__
        args = annotation.__args__

        if origin is list:
            # Route the element type through the model-aware helper so that
            # lists of nested models / fixed-size lists are supported too.
            return pa.list_(_pydantic_type_to_arrow_type(args[0], field))
        elif origin == Union:
            # Optional[X] is Union[X, None]
            if len(args) == 2 and args[1] is type(None):
                return _pydantic_type_to_arrow_type(args[0], field)
    elif sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
        # `X | None` syntax; types.UnionType only exists on Python 3.10+.
        args = annotation.__args__
        if len(args) == 2:
            for typ in args:
                if typ is type(None):
                    continue
                # Same helper as the Optional[X] branch above, so `Model | None`
                # behaves like Optional[Model].
                return _pydantic_type_to_arrow_type(typ, field)
    return _pydantic_type_to_arrow_type(annotation, field)
def test_optional_nested_model():
    """Nested pydantic models, required or Optional, convert to Arrow structs.

    The ``payload`` field of the generated schema must be a struct whose
    ``Optional`` members are nullable and whose nested models are themselves
    structs.
    """

    class WAMedia(BaseModel):
        url: str
        mimetype: str
        filename: Optional[str]
        error: Optional[str]
        data: bytes

    class WALocation(BaseModel):
        description: Optional[str]
        latitude: str
        longitude: str

    class ReplyToMessage(BaseModel):
        id: str
        participant: str
        body: str

    class Message(BaseModel):
        id: str
        timestamp: int
        from_: str
        fromMe: bool
        to: str
        body: str
        hasMedia: Optional[bool]
        media: WAMedia
        mediaUrl: Optional[str]
        ack: Optional[int]
        ackName: Optional[str]
        author: Optional[str]
        location: Optional[WALocation]
        vCards: Optional[List[str]]
        replyTo: Optional[ReplyToMessage]

    class AnyEvent(LanceModel):
        id: str
        session: str
        metadata: Optional[str] = None
        engine: str
        event: str

    class MessageEvent(AnyEvent):
        payload: Message

    actual_type = pydantic_to_schema(MessageEvent).field("payload").type

    def col(name, dtype, nullable=False):
        # Thin wrapper so required/optional columns read uniformly below.
        return pa.field(name, dtype, nullable)

    media_struct = pa.struct(
        [
            col("url", pa.utf8()),
            col("mimetype", pa.utf8()),
            col("filename", pa.utf8(), nullable=True),
            col("error", pa.utf8(), nullable=True),
            col("data", pa.binary()),
        ]
    )
    location_struct = pa.struct(
        [
            col("description", pa.utf8(), nullable=True),
            col("latitude", pa.utf8()),
            col("longitude", pa.utf8()),
        ]
    )
    reply_struct = pa.struct(
        [
            col("id", pa.utf8()),
            col("participant", pa.utf8()),
            col("body", pa.utf8()),
        ]
    )

    expected_type = pa.struct(
        [
            col("id", pa.utf8()),
            col("timestamp", pa.int64()),
            col("from_", pa.utf8()),
            col("fromMe", pa.bool_()),
            col("to", pa.utf8()),
            col("body", pa.utf8()),
            col("hasMedia", pa.bool_(), nullable=True),
            col("media", media_struct),
            col("mediaUrl", pa.utf8(), nullable=True),
            col("ack", pa.int64(), nullable=True),
            col("ackName", pa.utf8(), nullable=True),
            col("author", pa.utf8(), nullable=True),
            col("location", location_struct, nullable=True),  # Optional
            col("vCards", pa.list_(pa.utf8()), nullable=True),
            col("replyTo", reply_struct, nullable=True),
        ]
    )

    assert actual_type == expected_type