docs: improve pydantic integration docs (#2136)

Address usage mistakes in
https://github.com/lancedb/lancedb/issues/2135.

* Add example of how to use `LanceModel` and the `Vector` field type
* Add test for pydantic doc
* Fix the example to pass the `LanceModel` subclass directly as `schema=` instead of calling
`MyModel.to_arrow_schema()`.
* Add cross-reference link to pydantic doc site
* Configure mkdocs to watch code changes in python directory.
This commit is contained in:
Lei Xu
2025-02-21 12:48:37 -08:00
committed by GitHub
parent 544382df5e
commit 6fa1f37506
5 changed files with 76 additions and 22 deletions

View File

@@ -259,7 +259,8 @@ def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:
def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
"""Convert a Pydantic model to a PyArrow Schema.
"""Convert a [Pydantic Model][pydantic.BaseModel] to a
[PyArrow Schema][pyarrow.Schema].
Parameters
----------
@@ -269,24 +270,25 @@ def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
Returns
-------
pyarrow.Schema
The Arrow Schema
Examples
--------
>>> from typing import List, Optional
>>> import pydantic
>>> from lancedb.pydantic import pydantic_to_schema
>>> from lancedb.pydantic import pydantic_to_schema, Vector
>>> class FooModel(pydantic.BaseModel):
... id: int
... s: str
... vec: List[float]
... vec: Vector(1536) # fixed_size_list<item: float32>[1536]
... li: List[int]
...
>>> schema = pydantic_to_schema(FooModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
... pa.field("s", pa.utf8(), False),
... pa.field("vec", pa.list_(pa.float64()), False),
... pa.field("vec", pa.list_(pa.float32(), 1536)),
... pa.field("li", pa.list_(pa.int64()), False),
... ])
"""
@@ -308,7 +310,7 @@ class LanceModel(pydantic.BaseModel):
... vector: Vector(2)
...
>>> db = lancedb.connect("./example")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
>>> table = db.create_table("test", schema=TestModel)
>>> table.add([
... TestModel(name="test", vector=[1.0, 2.0])
... ])

View File

@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
# --8<-- [start:imports]
import lancedb
from lancedb.pydantic import Vector, LanceModel
# --8<-- [end:imports]
def test_pydantic_model(tmp_path):
    """Run the pydantic model example end to end against a database at ``tmp_path``.

    The test defines a ``LanceModel`` subclass, creates a ``"person"`` table
    using that class as the schema, adds two rows, and checks that the table
    reports two rows. It then searches with the query vector ``[0.0, 0.0]``
    limited to one result, converts the result with ``to_pydantic``, and
    expects the first returned model to be ``"bob"``.

    The ``# --8<-- [start:...]`` / ``[end:...]`` comments delimit named
    sections of this file (presumably pulled into the documentation by a
    snippet tool -- confirm against the docs build), so the statements
    between each pair should stay readable on their own.

    Parameters
    ----------
    tmp_path
        Directory the database is created in; it replaces the ``"./example"``
        URL that is only assigned inside the ``set_url`` section.
    """
    # --8<-- [start:base_model]
    class PersonModel(LanceModel):
        name: str
        age: int
        vector: Vector(2)
    # --8<-- [end:base_model]
    # --8<-- [start:set_url]
    url = "./example"
    # --8<-- [end:set_url]
    # The assignment above exists only inside the ``set_url`` section; the
    # value actually passed to ``lancedb.connect`` is ``tmp_path``.
    url = tmp_path
    # --8<-- [start:base_example]
    db = lancedb.connect(url)
    table = db.create_table("person", schema=PersonModel)
    table.add(
        [
            PersonModel(name="bob", age=1, vector=[1.0, 2.0]),
            PersonModel(name="alice", age=2, vector=[3.0, 4.0]),
        ]
    )
    assert table.count_rows() == 2
    person = table.search([0.0, 0.0]).limit(1).to_pydantic(PersonModel)
    assert person[0].name == "bob"
    # --8<-- [end:base_example]