From 9d8699f99e9ba2e7517bee0ae532780eb81ec057 Mon Sep 17 00:00:00 2001 From: Zelys Date: Fri, 3 Apr 2026 12:40:49 -0500 Subject: [PATCH] feat(python): support Enum types in Pydantic to Arrow schema conversion (#3232) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes #1846. Python `Enum` fields raised `TypeError: Converting Pydantic type to Arrow Type: unsupported type ` when converting a Pydantic model to an Arrow schema. The fix adds Enum detection in `_pydantic_type_to_arrow_type`. When an Enum subclass is encountered, the value type of its members is inspected and mapped to the appropriate Arrow type: - `str`-valued enums (e.g. `class Status(str, Enum)`) → `pa.dictionary(pa.int32(), pa.utf8())` (dictionary-encoded strings) - `int`-valued enums (e.g. `class Priority(int, Enum)`) → `pa.int64()` - Other homogeneous value types → the Arrow type for that Python type - Mixed-value or empty enums → `pa.utf8()` (safe fallback) This covers the common `(str, Enum)` and `(int, Enum)` mixin patterns used in practice. ## Changes - `python/python/lancedb/pydantic.py`: add Enum branch in `_pydantic_type_to_arrow_type` - `python/python/tests/test_pydantic.py`: add `test_enum_types` covering `str`, `int`, and `Optional` Enum fields ## Note on #2395 PR #2395 handles `StrEnum` (Python 3.11+) specifically, using a dictionary-encoded type. This PR handles the broader `(str, Enum)` / `(int, Enum)` mixin pattern that works across all Python versions; it dictionary-encodes string-valued enums and stores other homogeneous enum values as their natural Arrow type. AI assistance was used in developing this fix. 
--- python/python/lancedb/pydantic.py | 14 ++++++++++++++ python/python/tests/test_pydantic.py | 27 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 653ea3333..8b1f991f9 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -10,6 +10,7 @@ import sys import types from abc import ABC, abstractmethod from datetime import date, datetime +from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -314,6 +315,19 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType: return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim())) # For regular Vector return pa.list_(tp.value_arrow_type(), tp.dim()) + if _safe_issubclass(tp, Enum): + # Map Enum to the Arrow type of its value. + # For string-valued enums, use dictionary encoding for efficiency. + # For integer enums, use the native type. + # Fall back to utf8 for mixed-type or empty enums. 
+ value_types = {type(m.value) for m in tp} + if len(value_types) == 1: + value_type = value_types.pop() + if value_type is str: + # Use dictionary encoding for string enums + return pa.dictionary(pa.int32(), pa.utf8()) + return _py_type_to_arrow_type(value_type, field) + return pa.utf8() return _py_type_to_arrow_type(tp, field) diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index fd0bb2c64..701bbef5a 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -3,6 +3,7 @@ import json from datetime import date, datetime +from enum import Enum from typing import List, Optional, Tuple import pyarrow as pa @@ -673,3 +674,29 @@ async def test_aliases_in_lance_model_async(mem_db_async): assert hasattr(model, "name") assert hasattr(model, "distance") assert model.distance < 0.01 + + +def test_enum_types(): + """Enum fields should map to the Arrow type of their value (issue #1846).""" + + class StrStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + DONE = "done" + + class IntPriority(int, Enum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + class TestModel(pydantic.BaseModel): + status: StrStatus + priority: IntPriority + opt_status: Optional[StrStatus] = None + + schema = pydantic_to_schema(TestModel) + + assert schema.field("status").type == pa.dictionary(pa.int32(), pa.utf8()) + assert schema.field("priority").type == pa.int64() + assert schema.field("opt_status").type == pa.dictionary(pa.int32(), pa.utf8()) + assert schema.field("opt_status").nullable