From b71aa4117ff373a41d432ff94e45eb55cebb5312 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:10:09 -0800 Subject: [PATCH] feat(python): support list of list fields from pydantic schema (#747) For object detection, each row may correspond to an image and each image can have multiple bounding boxes of x-y coordinates. This means that a `bbox` field is potentially "list of list of float". This adds support in our pydantic-pyarrow conversion for nested lists. --- python/lancedb/pydantic.py | 3 +++ python/tests/test_pydantic.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index caa69405..537d60a0 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -164,6 +164,9 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType: return pa.date32() elif py_type == datetime: return pa.timestamp("us") + elif py_type.__origin__ in (list, tuple): + child = py_type.__args__[0] + return pa.list_(_py_type_to_arrow_type(child)) raise TypeError( f"Converting Pydantic type to Arrow Type: unsupported type {py_type}" ) diff --git a/python/tests/test_pydantic.py b/python/tests/test_pydantic.py index fa7e4849..e6739032 100644 --- a/python/tests/test_pydantic.py +++ b/python/tests/test_pydantic.py @@ -15,7 +15,7 @@ import json import sys from datetime import date, datetime -from typing import List, Optional +from typing import List, Optional, Tuple import pyarrow as pa import pydantic @@ -39,6 +39,8 @@ def test_pydantic_to_arrow(): s: str vec: list[float] li: List[int] + lili: List[List[float]] + litu: List[Tuple[float, float]] opt: Optional[str] = None st: StructModel dt: date @@ -50,6 +52,8 @@ def test_pydantic_to_arrow(): s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], + lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], + litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], st=StructModel(a="a", b=1.0), dt=date.today(), dtt=datetime.now(), @@ -63,6 +67,8 @@ def test_pydantic_to_arrow(): pa.field("s", pa.utf8(), False), pa.field("vec", pa.list_(pa.float64()), False), pa.field("li", pa.list_(pa.int64()), False), + pa.field("lili", pa.list_(pa.list_(pa.float64())), False), + pa.field("litu", pa.list_(pa.list_(pa.float64())), False), pa.field("opt", pa.utf8(), True), pa.field( "st", @@ -88,6 +94,8 @@ def test_pydantic_to_arrow_py38(): s: str vec: List[float] li: List[int] + lili: List[List[float]] + litu: List[Tuple[float, float]] opt: Optional[str] = None st: StructModel dt: date @@ -99,6 +107,8 @@ def test_pydantic_to_arrow_py38(): s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], + lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], + litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], st=StructModel(a="a", b=1.0), dt=date.today(), dtt=datetime.now(), @@ -112,6 +122,8 @@ def test_pydantic_to_arrow_py38(): pa.field("s", pa.utf8(), False), pa.field("vec", pa.list_(pa.float64()), False), pa.field("li", pa.list_(pa.int64()), False), + pa.field("lili", pa.list_(pa.list_(pa.float64())), False), + pa.field("litu", pa.list_(pa.list_(pa.float64())), False), pa.field("opt", pa.utf8(), True), pa.field( "st",