feat: add timezone handling for datetime in pydantic (#578)

If you add timezone information in the Field annotation for a datetime
then that will now be passed to the pyarrow data type.

I'm not sure how pyarrow enforces timezones, right now, it silently
coerces to the timezone given in the column regardless of whether the
input had the matching timezone or not. This is probably not the right
behavior. Though we could just make it so the user has to make the
pydantic model do the validation instead of doing that at the pyarrow
conversion layer.
This commit is contained in:
Chang She
2023-12-28 11:02:56 -08:00
committed by GitHub
parent c8728d4ca1
commit 4b8af261a3
4 changed files with 105 additions and 14 deletions

View File

@@ -118,6 +118,84 @@ This guide will show how to create tables, insert data into them, and update the
table = db.create_table(table_name, schema=Content) table = db.create_table(table_name, schema=Content)
``` ```
#### Nested schemas
Sometimes your data model may contain nested objects.
For example, you may want to store the document string
and the document soure name as a nested Document object:
```python
class Document(BaseModel):
content: str
source: str
```
This can be used as the type of a LanceDB table column:
```python
class NestedSchema(LanceModel):
id: str
vector: Vector(1536)
document: Document
tbl = db.create_table("nested_table", schema=NestedSchema, mode="overwrite")
```
This creates a struct column called "document" that has two subfields
called "content" and "source":
```
In [28]: tbl.schema
Out[28]:
id: string not null
vector: fixed_size_list<item: float>[1536] not null
child 0, item: float
document: struct<content: string not null, source: string not null> not null
child 0, content: string not null
child 1, source: string not null
```
#### Validators
Note that neither pydantic nor pyarrow automatically validates that input data
is of the *correct* timezone, but this is easy to add as a custom field validator:
```python
from datetime import datetime
from zoneinfo import ZoneInfo
from lancedb.pydantic import LanceModel
from pydantic import Field, field_validator, ValidationError, ValidationInfo
tzname = "America/New_York"
tz = ZoneInfo(tzname)
class TestModel(LanceModel):
dt_with_tz: datetime = Field(json_schema_extra={"tz": tzname})
@field_validator('dt_with_tz')
@classmethod
def tz_must_match(cls, dt: datetime) -> datetime:
assert dt.tzinfo == tz
return dt
ok = TestModel(dt_with_tz=datetime.now(tz))
try:
TestModel(dt_with_tz=datetime.now(ZoneInfo("Asia/Shanghai")))
assert 0 == 1, "this should raise ValidationError"
except ValidationError:
print("A ValidationError was raised.")
pass
```
When you run this code it should print "A ValidationError was raised."
#### Pydantic custom types
LanceDB does NOT yet support converting pydantic custom types. If this is something you need,
please file a feature request on the [LanceDB Github repo](https://github.com/lancedb/lancedb/issues/new).
### Using Iterators / Writing Large Datasets ### Using Iterators / Writing Large Datasets
It is recommended to use itertators to add large datasets in batches when creating your table in one go. This does not create multiple versions of your dataset unlike manually adding batches using `table.add()` It is recommended to use itertators to add large datasets in batches when creating your table in one go. This does not create multiple versions of your dataset unlike manually adding batches using `table.add()`
@@ -153,7 +231,7 @@ This guide will show how to create tables, insert data into them, and update the
You can also use iterators of other types like Pandas dataframe or Pylists directly in the above example. You can also use iterators of other types like Pandas dataframe or Pylists directly in the above example.
## Creating Empty Table ## Creating Empty Table
You can also create empty tables in python. Initialize it with schema and later ingest data into it. You can create empty tables in python. Initialize it with schema and later ingest data into it.
```python ```python
import lancedb import lancedb

View File

@@ -26,6 +26,7 @@ import numpy as np
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import semver import semver
from pydantic.fields import FieldInfo
from .embeddings import EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionRegistry
@@ -142,8 +143,8 @@ def Vector(
return FixedSizeList return FixedSizeList
def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType: def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
"""Convert Python Type to Arrow DataType. """Convert a field with native Python type to Arrow data type.
Raises Raises
------ ------
@@ -163,12 +164,13 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
elif py_type == date: elif py_type == date:
return pa.date32() return pa.date32()
elif py_type == datetime: elif py_type == datetime:
return pa.timestamp("us") tz = get_extras(field, "tz")
elif py_type.__origin__ in (list, tuple): return pa.timestamp("us", tz=tz)
elif getattr(py_type, "__origin__", None) in (list, tuple):
child = py_type.__args__[0] child = py_type.__args__[0]
return pa.list_(_py_type_to_arrow_type(child)) return pa.list_(_py_type_to_arrow_type(child, field))
raise TypeError( raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}" f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
) )
@@ -197,10 +199,10 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
args = field.annotation.__args__ args = field.annotation.__args__
if origin == list: if origin == list:
child = args[0] child = args[0]
return pa.list_(_py_type_to_arrow_type(child)) return pa.list_(_py_type_to_arrow_type(child, field))
elif origin == Union: elif origin == Union:
if len(args) == 2 and args[1] == type(None): if len(args) == 2 and args[1] == type(None):
return _py_type_to_arrow_type(args[0]) return _py_type_to_arrow_type(args[0], field)
elif inspect.isclass(field.annotation): elif inspect.isclass(field.annotation):
if issubclass(field.annotation, pydantic.BaseModel): if issubclass(field.annotation, pydantic.BaseModel):
# Struct # Struct
@@ -208,7 +210,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
return pa.struct(fields) return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin): elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim()) return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
return _py_type_to_arrow_type(field.annotation) return _py_type_to_arrow_type(field.annotation, field)
def is_nullable(field: pydantic.fields.FieldInfo) -> bool: def is_nullable(field: pydantic.fields.FieldInfo) -> bool:

View File

@@ -49,7 +49,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies] [project.optional-dependencies]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb"] tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"]
dev = ["ruff", "pre-commit", "black"] dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]

View File

@@ -13,6 +13,7 @@
import json import json
import pytz
import sys import sys
from datetime import date, datetime from datetime import date, datetime
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -38,13 +39,14 @@ def test_pydantic_to_arrow():
id: int id: int
s: str s: str
vec: list[float] vec: list[float]
li: List[int] li: list[int]
lili: List[List[float]] lili: list[list[float]]
litu: List[Tuple[float, float]] litu: list[tuple[float, float]]
opt: Optional[str] = None opt: Optional[str] = None
st: StructModel st: StructModel
dt: date dt: date
dtt: datetime dtt: datetime
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict # d: dict
m = TestModel( m = TestModel(
@@ -57,6 +59,7 @@ def test_pydantic_to_arrow():
st=StructModel(a="a", b=1.0), st=StructModel(a="a", b=1.0),
dt=date.today(), dt=date.today(),
dtt=datetime.now(), dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
) )
schema = pydantic_to_schema(TestModel) schema = pydantic_to_schema(TestModel)
@@ -79,11 +82,16 @@ def test_pydantic_to_arrow():
), ),
pa.field("dt", pa.date32(), False), pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False), pa.field("dtt", pa.timestamp("us"), False),
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
] ]
) )
assert schema == expect_schema assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info > (3, 8),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow_py38(): def test_pydantic_to_arrow_py38():
class StructModel(pydantic.BaseModel): class StructModel(pydantic.BaseModel):
a: str a: str
@@ -100,6 +108,7 @@ def test_pydantic_to_arrow_py38():
st: StructModel st: StructModel
dt: date dt: date
dtt: datetime dtt: datetime
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict # d: dict
m = TestModel( m = TestModel(
@@ -112,6 +121,7 @@ def test_pydantic_to_arrow_py38():
st=StructModel(a="a", b=1.0), st=StructModel(a="a", b=1.0),
dt=date.today(), dt=date.today(),
dtt=datetime.now(), dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
) )
schema = pydantic_to_schema(TestModel) schema = pydantic_to_schema(TestModel)
@@ -134,6 +144,7 @@ def test_pydantic_to_arrow_py38():
), ),
pa.field("dt", pa.date32(), False), pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False), pa.field("dtt", pa.timestamp("us"), False),
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
] ]
) )
assert schema == expect_schema assert schema == expect_schema