feat: add timezone handling for datetime in pydantic (#578)

If you add timezone information in the Field annotation for a datetime
then that will now be passed to the pyarrow data type.

I'm not sure how pyarrow enforces timezones, right now, it silently
coerces to the timezone given in the column regardless of whether the
input had the matching timezone or not. This is probably not the right
behavior. Though we could just make it so the user has to make the
pydantic model do the validation instead of doing that at the pyarrow
conversion layer.
This commit is contained in:
Chang She
2023-12-28 11:02:56 -08:00
committed by Weston Pace
parent bc83bc9838
commit 75d575ef4e
4 changed files with 105 additions and 14 deletions

View File

@@ -26,6 +26,7 @@ import numpy as np
import pyarrow as pa
import pydantic
import semver
from pydantic.fields import FieldInfo
from .embeddings import EmbeddingFunctionRegistry
@@ -142,8 +143,8 @@ def Vector(
return FixedSizeList
def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
"""Convert Python Type to Arrow DataType.
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
"""Convert a field with native Python type to Arrow data type.
Raises
------
@@ -163,12 +164,13 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
elif py_type == date:
return pa.date32()
elif py_type == datetime:
return pa.timestamp("us")
elif py_type.__origin__ in (list, tuple):
tz = get_extras(field, "tz")
return pa.timestamp("us", tz=tz)
elif getattr(py_type, "__origin__", None) in (list, tuple):
child = py_type.__args__[0]
return pa.list_(_py_type_to_arrow_type(child))
return pa.list_(_py_type_to_arrow_type(child, field))
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
)
@@ -197,10 +199,10 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
args = field.annotation.__args__
if origin == list:
child = args[0]
return pa.list_(_py_type_to_arrow_type(child))
return pa.list_(_py_type_to_arrow_type(child, field))
elif origin == Union:
if len(args) == 2 and args[1] == type(None):
return _py_type_to_arrow_type(args[0])
return _py_type_to_arrow_type(args[0], field)
elif inspect.isclass(field.annotation):
if issubclass(field.annotation, pydantic.BaseModel):
# Struct
@@ -208,7 +210,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
return _py_type_to_arrow_type(field.annotation)
return _py_type_to_arrow_type(field.annotation, field)
def is_nullable(field: pydantic.fields.FieldInfo) -> bool:

View File

@@ -46,7 +46,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb"]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"]
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]

View File

@@ -13,6 +13,7 @@
import json
import pytz
import sys
from datetime import date, datetime
from typing import List, Optional, Tuple
@@ -38,13 +39,14 @@ def test_pydantic_to_arrow():
id: int
s: str
vec: list[float]
li: List[int]
lili: List[List[float]]
litu: List[Tuple[float, float]]
li: list[int]
lili: list[list[float]]
litu: list[tuple[float, float]]
opt: Optional[str] = None
st: StructModel
dt: date
dtt: datetime
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict
m = TestModel(
@@ -57,6 +59,7 @@ def test_pydantic_to_arrow():
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
)
schema = pydantic_to_schema(TestModel)
@@ -79,11 +82,16 @@ def test_pydantic_to_arrow():
),
pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False),
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
]
)
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info > (3, 8),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow_py38():
class StructModel(pydantic.BaseModel):
a: str
@@ -100,6 +108,7 @@ def test_pydantic_to_arrow_py38():
st: StructModel
dt: date
dtt: datetime
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict
m = TestModel(
@@ -112,6 +121,7 @@ def test_pydantic_to_arrow_py38():
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
)
schema = pydantic_to_schema(TestModel)
@@ -134,6 +144,7 @@ def test_pydantic_to_arrow_py38():
),
pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False),
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
]
)
assert schema == expect_schema