feat(python): support new style optional syntax (#793)

This commit is contained in:
Chang She
2024-01-09 07:03:29 -08:00
committed by Andrew Miracle
parent 615c469af2
commit ba01d274eb
3 changed files with 36 additions and 1 deletions

View File

@@ -192,6 +192,7 @@ else:
def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType"""
if isinstance(field.annotation, _GenericAlias) or (
sys.version_info > (3, 9) and isinstance(field.annotation, types.GenericAlias)
):
@@ -203,6 +204,13 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
elif origin == Union:
if len(args) == 2 and args[1] == type(None):
return _py_type_to_arrow_type(args[0], field)
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
if len(args) == 2:
for typ in args:
if typ == type(None):
continue
return _py_type_to_arrow_type(typ, field)
elif inspect.isclass(field.annotation):
if issubclass(field.annotation, pydantic.BaseModel):
# Struct
@@ -221,6 +229,11 @@ def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
if origin == Union:
if len(args) == 2 and args[1] == type(None):
return True
elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
args = field.annotation.__args__
for typ in args:
if typ == type(None):
return True
return False

View File

@@ -82,7 +82,7 @@ def test_search_index(tmp_path, table):
def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text")
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
assert len(df) == 10
assert len(df) <= 10
assert "text" in df.columns
# Check whether it can be updated

View File

@@ -88,6 +88,28 @@ def test_pydantic_to_arrow():
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="using | type syntax requires python3.10 or higher",
)
def test_optional_types_py310():
class TestModel(pydantic.BaseModel):
a: str | None
b: None | str
c: Optional[str]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("a", pa.utf8(), True),
pa.field("b", pa.utf8(), True),
pa.field("c", pa.utf8(), True),
]
)
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info > (3, 8),
reason="using native type alias requires python3.9 or higher",