mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 04:12:59 +00:00
feat: add timezone handling for datetime in pydantic (#578)
If you add timezone information in the Field annotation for a datetime, then that will now be passed to the pyarrow data type. I'm not sure how pyarrow enforces timezones; right now, it silently coerces to the timezone given in the column regardless of whether the input had the matching timezone or not. This is probably not the right behavior. Alternatively, we could make it so the user has to make the pydantic model do the validation instead of doing that at the pyarrow conversion layer.
This commit is contained in:
@@ -118,6 +118,84 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
table = db.create_table(table_name, schema=Content)
|
table = db.create_table(table_name, schema=Content)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Nested schemas
|
||||||
|
|
||||||
|
Sometimes your data model may contain nested objects.
|
||||||
|
For example, you may want to store the document string
|
||||||
|
and the document source name as a nested Document object:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Document(BaseModel):
|
||||||
|
content: str
|
||||||
|
source: str
|
||||||
|
```
|
||||||
|
|
||||||
|
This can be used as the type of a LanceDB table column:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class NestedSchema(LanceModel):
|
||||||
|
id: str
|
||||||
|
vector: Vector(1536)
|
||||||
|
document: Document
|
||||||
|
|
||||||
|
tbl = db.create_table("nested_table", schema=NestedSchema, mode="overwrite")
|
||||||
|
```
|
||||||
|
|
||||||
|
This creates a struct column called "document" that has two subfields
|
||||||
|
called "content" and "source":
|
||||||
|
|
||||||
|
```
|
||||||
|
In [28]: tbl.schema
|
||||||
|
Out[28]:
|
||||||
|
id: string not null
|
||||||
|
vector: fixed_size_list<item: float>[1536] not null
|
||||||
|
child 0, item: float
|
||||||
|
document: struct<content: string not null, source: string not null> not null
|
||||||
|
child 0, content: string not null
|
||||||
|
child 1, source: string not null
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Validators
|
||||||
|
|
||||||
|
Note that neither pydantic nor pyarrow automatically validates that input data
|
||||||
|
is of the *correct* timezone, but this is easy to add as a custom field validator:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datetime import datetime
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
from lancedb.pydantic import LanceModel
|
||||||
|
from pydantic import Field, field_validator, ValidationError, ValidationInfo
|
||||||
|
|
||||||
|
tzname = "America/New_York"
|
||||||
|
tz = ZoneInfo(tzname)
|
||||||
|
|
||||||
|
class TestModel(LanceModel):
|
||||||
|
dt_with_tz: datetime = Field(json_schema_extra={"tz": tzname})
|
||||||
|
|
||||||
|
@field_validator('dt_with_tz')
|
||||||
|
@classmethod
|
||||||
|
def tz_must_match(cls, dt: datetime) -> datetime:
|
||||||
|
assert dt.tzinfo == tz
|
||||||
|
return dt
|
||||||
|
|
||||||
|
ok = TestModel(dt_with_tz=datetime.now(tz))
|
||||||
|
|
||||||
|
try:
|
||||||
|
TestModel(dt_with_tz=datetime.now(ZoneInfo("Asia/Shanghai")))
|
||||||
|
assert 0 == 1, "this should raise ValidationError"
|
||||||
|
except ValidationError:
|
||||||
|
print("A ValidationError was raised.")
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
When you run this code it should print "A ValidationError was raised."
|
||||||
|
|
||||||
|
#### Pydantic custom types
|
||||||
|
|
||||||
|
LanceDB does NOT yet support converting pydantic custom types. If this is something you need,
|
||||||
|
please file a feature request on the [LanceDB Github repo](https://github.com/lancedb/lancedb/issues/new).
|
||||||
|
|
||||||
### Using Iterators / Writing Large Datasets
|
### Using Iterators / Writing Large Datasets
|
||||||
|
|
||||||
It is recommended to use iterators to add large datasets in batches when creating your table in one go. This does not create multiple versions of your dataset unlike manually adding batches using `table.add()`
|
It is recommended to use iterators to add large datasets in batches when creating your table in one go. This does not create multiple versions of your dataset unlike manually adding batches using `table.add()`
|
||||||
@@ -153,7 +231,7 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
You can also use iterators of other types like Pandas dataframe or Pylists directly in the above example.
|
You can also use iterators of other types like Pandas dataframe or Pylists directly in the above example.
|
||||||
|
|
||||||
## Creating Empty Table
|
## Creating Empty Table
|
||||||
You can also create empty tables in python. Initialize it with schema and later ingest data into it.
|
You can create empty tables in python. Initialize it with schema and later ingest data into it.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import lancedb
|
import lancedb
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import numpy as np
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pydantic
|
import pydantic
|
||||||
import semver
|
import semver
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
from .embeddings import EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionRegistry
|
||||||
|
|
||||||
@@ -142,8 +143,8 @@ def Vector(
|
|||||||
return FixedSizeList
|
return FixedSizeList
|
||||||
|
|
||||||
|
|
||||||
def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
|
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||||
"""Convert Python Type to Arrow DataType.
|
"""Convert a field with native Python type to Arrow data type.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
@@ -163,12 +164,13 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
|
|||||||
elif py_type == date:
|
elif py_type == date:
|
||||||
return pa.date32()
|
return pa.date32()
|
||||||
elif py_type == datetime:
|
elif py_type == datetime:
|
||||||
return pa.timestamp("us")
|
tz = get_extras(field, "tz")
|
||||||
elif py_type.__origin__ in (list, tuple):
|
return pa.timestamp("us", tz=tz)
|
||||||
|
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
||||||
child = py_type.__args__[0]
|
child = py_type.__args__[0]
|
||||||
return pa.list_(_py_type_to_arrow_type(child))
|
return pa.list_(_py_type_to_arrow_type(child, field))
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
|
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -197,10 +199,10 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
|
|||||||
args = field.annotation.__args__
|
args = field.annotation.__args__
|
||||||
if origin == list:
|
if origin == list:
|
||||||
child = args[0]
|
child = args[0]
|
||||||
return pa.list_(_py_type_to_arrow_type(child))
|
return pa.list_(_py_type_to_arrow_type(child, field))
|
||||||
elif origin == Union:
|
elif origin == Union:
|
||||||
if len(args) == 2 and args[1] == type(None):
|
if len(args) == 2 and args[1] == type(None):
|
||||||
return _py_type_to_arrow_type(args[0])
|
return _py_type_to_arrow_type(args[0], field)
|
||||||
elif inspect.isclass(field.annotation):
|
elif inspect.isclass(field.annotation):
|
||||||
if issubclass(field.annotation, pydantic.BaseModel):
|
if issubclass(field.annotation, pydantic.BaseModel):
|
||||||
# Struct
|
# Struct
|
||||||
@@ -208,7 +210,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
|
|||||||
return pa.struct(fields)
|
return pa.struct(fields)
|
||||||
elif issubclass(field.annotation, FixedSizeListMixin):
|
elif issubclass(field.annotation, FixedSizeListMixin):
|
||||||
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
||||||
return _py_type_to_arrow_type(field.annotation)
|
return _py_type_to_arrow_type(field.annotation, field)
|
||||||
|
|
||||||
|
|
||||||
def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
|
def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ classifiers = [
|
|||||||
repository = "https://github.com/lancedb/lancedb"
|
repository = "https://github.com/lancedb/lancedb"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb"]
|
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"]
|
||||||
dev = ["ruff", "pre-commit", "black"]
|
dev = ["ruff", "pre-commit", "black"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import pytz
|
||||||
import sys
|
import sys
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
@@ -38,13 +39,14 @@ def test_pydantic_to_arrow():
|
|||||||
id: int
|
id: int
|
||||||
s: str
|
s: str
|
||||||
vec: list[float]
|
vec: list[float]
|
||||||
li: List[int]
|
li: list[int]
|
||||||
lili: List[List[float]]
|
lili: list[list[float]]
|
||||||
litu: List[Tuple[float, float]]
|
litu: list[tuple[float, float]]
|
||||||
opt: Optional[str] = None
|
opt: Optional[str] = None
|
||||||
st: StructModel
|
st: StructModel
|
||||||
dt: date
|
dt: date
|
||||||
dtt: datetime
|
dtt: datetime
|
||||||
|
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
|
||||||
# d: dict
|
# d: dict
|
||||||
|
|
||||||
m = TestModel(
|
m = TestModel(
|
||||||
@@ -57,6 +59,7 @@ def test_pydantic_to_arrow():
|
|||||||
st=StructModel(a="a", b=1.0),
|
st=StructModel(a="a", b=1.0),
|
||||||
dt=date.today(),
|
dt=date.today(),
|
||||||
dtt=datetime.now(),
|
dtt=datetime.now(),
|
||||||
|
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = pydantic_to_schema(TestModel)
|
schema = pydantic_to_schema(TestModel)
|
||||||
@@ -79,11 +82,16 @@ def test_pydantic_to_arrow():
|
|||||||
),
|
),
|
||||||
pa.field("dt", pa.date32(), False),
|
pa.field("dt", pa.date32(), False),
|
||||||
pa.field("dtt", pa.timestamp("us"), False),
|
pa.field("dtt", pa.timestamp("us"), False),
|
||||||
|
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info > (3, 8),
|
||||||
|
reason="using native type alias requires python3.9 or higher",
|
||||||
|
)
|
||||||
def test_pydantic_to_arrow_py38():
|
def test_pydantic_to_arrow_py38():
|
||||||
class StructModel(pydantic.BaseModel):
|
class StructModel(pydantic.BaseModel):
|
||||||
a: str
|
a: str
|
||||||
@@ -100,6 +108,7 @@ def test_pydantic_to_arrow_py38():
|
|||||||
st: StructModel
|
st: StructModel
|
||||||
dt: date
|
dt: date
|
||||||
dtt: datetime
|
dtt: datetime
|
||||||
|
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
|
||||||
# d: dict
|
# d: dict
|
||||||
|
|
||||||
m = TestModel(
|
m = TestModel(
|
||||||
@@ -112,6 +121,7 @@ def test_pydantic_to_arrow_py38():
|
|||||||
st=StructModel(a="a", b=1.0),
|
st=StructModel(a="a", b=1.0),
|
||||||
dt=date.today(),
|
dt=date.today(),
|
||||||
dtt=datetime.now(),
|
dtt=datetime.now(),
|
||||||
|
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = pydantic_to_schema(TestModel)
|
schema = pydantic_to_schema(TestModel)
|
||||||
@@ -134,6 +144,7 @@ def test_pydantic_to_arrow_py38():
|
|||||||
),
|
),
|
||||||
pa.field("dt", pa.date32(), False),
|
pa.field("dt", pa.date32(), False),
|
||||||
pa.field("dtt", pa.timestamp("us"), False),
|
pa.field("dtt", pa.timestamp("us"), False),
|
||||||
|
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
|
|||||||
Reference in New Issue
Block a user