feat(python): support list of list fields from pydantic schema (#747)

For object detection, each row may correspond to an image, and each image
can have multiple bounding boxes, each given as x-y coordinates. This means
that a `bbox` field is potentially a "list of list of float". This commit
adds support for nested lists to our pydantic-to-pyarrow conversion.
This commit is contained in:
Chang She
2023-12-27 09:10:09 -08:00
committed by Andrew Miracle
parent 55db26f59a
commit b71aa4117f
2 changed files with 16 additions and 1 deletion

View File

@@ -164,6 +164,9 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
return pa.date32()
elif py_type == datetime:
return pa.timestamp("us")
elif py_type.__origin__ in (list, tuple):
child = py_type.__args__[0]
return pa.list_(_py_type_to_arrow_type(child))
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
)

View File

@@ -15,7 +15,7 @@
import json
import sys
from datetime import date, datetime
from typing import List, Optional
from typing import List, Optional, Tuple
import pyarrow as pa
import pydantic
@@ -39,6 +39,8 @@ def test_pydantic_to_arrow():
s: str
vec: list[float]
li: List[int]
lili: List[List[float]]
litu: List[Tuple[float, float]]
opt: Optional[str] = None
st: StructModel
dt: date
@@ -50,6 +52,8 @@ def test_pydantic_to_arrow():
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
@@ -63,6 +67,8 @@ def test_pydantic_to_arrow():
pa.field("s", pa.utf8(), False),
pa.field("vec", pa.list_(pa.float64()), False),
pa.field("li", pa.list_(pa.int64()), False),
pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
pa.field("opt", pa.utf8(), True),
pa.field(
"st",
@@ -88,6 +94,8 @@ def test_pydantic_to_arrow_py38():
s: str
vec: List[float]
li: List[int]
lili: List[List[float]]
litu: List[Tuple[float, float]]
opt: Optional[str] = None
st: StructModel
dt: date
@@ -99,6 +107,8 @@ def test_pydantic_to_arrow_py38():
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
@@ -112,6 +122,8 @@ def test_pydantic_to_arrow_py38():
pa.field("s", pa.utf8(), False),
pa.field("vec", pa.list_(pa.float64()), False),
pa.field("li", pa.list_(pa.int64()), False),
pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
pa.field("opt", pa.utf8(), True),
pa.field(
"st",