From ba01d274eba0d2c79efb3ee6b44cbed515f1eadc Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Tue, 9 Jan 2024 07:03:29 -0800 Subject: [PATCH] feat(python): support new style optional syntax (#793) --- python/lancedb/pydantic.py | 13 +++++++++++++ python/tests/test_fts.py | 2 +- python/tests/test_pydantic.py | 22 ++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 48a67189..859eeaa8 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -192,6 +192,7 @@ else: def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType: """Convert a Pydantic FieldInfo to Arrow DataType""" + if isinstance(field.annotation, _GenericAlias) or ( sys.version_info > (3, 9) and isinstance(field.annotation, types.GenericAlias) ): @@ -203,6 +204,13 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType: elif origin == Union: if len(args) == 2 and args[1] == type(None): return _py_type_to_arrow_type(args[0], field) + elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType): + args = field.annotation.__args__ + if len(args) == 2: + for typ in args: + if typ == type(None): + continue + return _py_type_to_arrow_type(typ, field) elif inspect.isclass(field.annotation): if issubclass(field.annotation, pydantic.BaseModel): # Struct @@ -221,6 +229,11 @@ def is_nullable(field: pydantic.fields.FieldInfo) -> bool: if origin == Union: if len(args) == 2 and args[1] == type(None): return True + elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType): + args = field.annotation.__args__ + for typ in args: + if typ == type(None): + return True return False diff --git a/python/tests/test_fts.py b/python/tests/test_fts.py index b7c81f61..f65dc4ca 100644 --- a/python/tests/test_fts.py +++ b/python/tests/test_fts.py @@ -82,7 +82,7 @@ def test_search_index(tmp_path, table): def test_create_index_from_table(tmp_path, table): table.create_fts_index("text") df = table.search("puppy").limit(10).select(["text"]).to_pandas() - assert len(df) == 10 + assert len(df) <= 10 assert "text" in df.columns # Check whether it can be updated diff --git a/python/tests/test_pydantic.py b/python/tests/test_pydantic.py index 8a3ee16b..c6376dce 100644 --- a/python/tests/test_pydantic.py +++ b/python/tests/test_pydantic.py @@ -88,6 +88,28 @@ def test_pydantic_to_arrow(): assert schema == expect_schema +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="using | type syntax requires python3.10 or higher", +) +def test_optional_types_py310(): + class TestModel(pydantic.BaseModel): + a: str | None + b: None | str + c: Optional[str] + + schema = pydantic_to_schema(TestModel) + + expect_schema = pa.schema( + [ + pa.field("a", pa.utf8(), True), + pa.field("b", pa.utf8(), True), + pa.field("c", pa.utf8(), True), + ] + ) + assert schema == expect_schema + + @pytest.mark.skipif( sys.version_info > (3, 8), reason="using native type alias requires python3.9 or higher",